MilesCranmer commited on
Commit
8f3db29
·
unverified ·
2 Parent(s): 0683428 6b6d1bb

Merge pull request #24 from DhananjayAshok/recover

Browse files

Refactor Python into more functions, and Julia into more files.

.gitignore CHANGED
@@ -8,3 +8,7 @@ trials*
8
  **/__pycache__
9
  build
10
  dist
 
 
 
 
 
8
  **/__pycache__
9
  build
10
  dist
11
+ *.vs/*
12
+ *.pyproj
13
+ *.sln
14
+ pysr/.vs/
README.md CHANGED
@@ -69,9 +69,10 @@ pip install pysr
69
 
70
  # Quickstart
71
 
 
72
  ```python
73
  import numpy as np
74
- from pysr import pysr, best, get_hof
75
 
76
  # Dataset
77
  X = 2*np.random.randn(100, 5)
@@ -108,4 +109,3 @@ This is a pandas table, with additional columns:
108
  - `score` - a metric akin to Occam's razor; you should use this to help select the "true" equation.
109
  - `sympy_format` - sympy equation.
110
  - `lambda_format` - a lambda function for that equation, that you can pass values through.
111
-
 
69
 
70
  # Quickstart
71
 
72
+ Here is some demo code (also found in `example.py`)
73
  ```python
74
  import numpy as np
75
+ from pysr import pysr, best
76
 
77
  # Dataset
78
  X = 2*np.random.randn(100, 5)
 
109
  - `score` - a metric akin to Occam's razor; you should use this to help select the "true" equation.
110
  - `sympy_format` - sympy equation.
111
  - `lambda_format` - a lambda function for that equation, that you can pass values through.
 
example.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from pysr import pysr, best
3
+
4
+ # Dataset
5
+ X = 2*np.random.randn(100, 5)
6
+ y = 2*np.cos(X[:, 3]) + X[:, 0]**2 - 2
7
+
8
+ # Learn equations
9
+ equations = pysr(X, y, niterations=5,
10
+ binary_operators=["plus", "mult"],
11
+ unary_operators=[
12
+ "cos", "exp", "sin", #Pre-defined library of operators (see https://pysr.readthedocs.io/en/latest/docs/operators/)
13
+ "inv(x) = 1/x"]) # Define your own operator! (Julia syntax)
14
+
15
+ ...# (you can use ctl-c to exit early)
16
+
17
+ print(best(equations))
julia/CheckConstraints.jl ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Check if any binary operator are overly complex
2
+ function flagBinOperatorComplexity(tree::Node, op::Int)::Bool
3
+ if tree.degree == 0
4
+ return false
5
+ elseif tree.degree == 1
6
+ return flagBinOperatorComplexity(tree.l, op)
7
+ else
8
+ if tree.op == op
9
+ overly_complex = (
10
+ ((bin_constraints[op][1] > -1) &&
11
+ (countNodes(tree.l) > bin_constraints[op][1]))
12
+ ||
13
+ ((bin_constraints[op][2] > -1) &&
14
+ (countNodes(tree.r) > bin_constraints[op][2]))
15
+ )
16
+ if overly_complex
17
+ return true
18
+ end
19
+ end
20
+ return (flagBinOperatorComplexity(tree.l, op) || flagBinOperatorComplexity(tree.r, op))
21
+ end
22
+ end
23
+
24
+ # Check if any unary operators are overly complex
25
+ function flagUnaOperatorComplexity(tree::Node, op::Int)::Bool
26
+ if tree.degree == 0
27
+ return false
28
+ elseif tree.degree == 1
29
+ if tree.op == op
30
+ overly_complex = (
31
+ (una_constraints[op] > -1) &&
32
+ (countNodes(tree.l) > una_constraints[op])
33
+ )
34
+ if overly_complex
35
+ return true
36
+ end
37
+ end
38
+ return flagUnaOperatorComplexity(tree.l, op)
39
+ else
40
+ return (flagUnaOperatorComplexity(tree.l, op) || flagUnaOperatorComplexity(tree.r, op))
41
+ end
42
+ end
julia/ConstantOptimization.jl ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import Optim
2
+
3
+ # Proxy function for optimization
4
+ function optFunc(x::Array{Float32, 1}, tree::Node)::Float32
5
+ setConstants(tree, x)
6
+ return scoreFunc(tree)
7
+ end
8
+
9
+ # Use Nelder-Mead to optimize the constants in an equation
10
+ function optimizeConstants(member::PopMember)::PopMember
11
+ nconst = countConstants(member.tree)
12
+ if nconst == 0
13
+ return member
14
+ end
15
+ x0 = getConstants(member.tree)
16
+ f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
17
+ if size(x0)[1] == 1
18
+ algorithm = Optim.Newton
19
+ else
20
+ algorithm = Optim.NelderMead
21
+ end
22
+
23
+ try
24
+ result = Optim.optimize(f, x0, algorithm(), Optim.Options(iterations=100))
25
+ # Try other initial conditions:
26
+ for i=1:nrestarts
27
+ tmpresult = Optim.optimize(f, x0 .* (1f0 .+ 5f-1*randn(Float32, size(x0)[1])), algorithm(), Optim.Options(iterations=100))
28
+ if tmpresult.minimum < result.minimum
29
+ result = tmpresult
30
+ end
31
+ end
32
+
33
+ if Optim.converged(result)
34
+ setConstants(member.tree, result.minimizer)
35
+ member.score = convert(Float32, result.minimum)
36
+ member.birth = getTime()
37
+ else
38
+ setConstants(member.tree, x0)
39
+ end
40
+ catch error
41
+ # Fine if optimization encountered domain error, just return x0
42
+ if isa(error, AssertionError)
43
+ setConstants(member.tree, x0)
44
+ else
45
+ throw(error)
46
+ end
47
+ end
48
+ return member
49
+ end
julia/Equation.jl ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Define a serialization format for the symbolic equations:
2
+ mutable struct Node
3
+ #Holds operators, variables, constants in a tree
4
+ degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
5
+ val::Union{Float32, Integer} #Either const value, or enumerates variable
6
+ constant::Bool #false if variable
7
+ op::Integer #enumerates operator (separately for degree=1,2)
8
+ l::Union{Node, Nothing}
9
+ r::Union{Node, Nothing}
10
+
11
+ Node(val::Float32) = new(0, val, true, 1, nothing, nothing)
12
+ Node(val::Integer) = new(0, val, false, 1, nothing, nothing)
13
+ Node(op::Integer, l::Node) = new(1, 0.0f0, false, op, l, nothing)
14
+ Node(op::Integer, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
15
+ Node(op::Integer, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
16
+
17
+ #Allow to pass the leaf value without additional node call:
18
+ Node(op::Integer, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
19
+ Node(op::Integer, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
20
+ Node(op::Integer, l::Union{Float32, Integer}, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, Node(l), Node(r))
21
+ end
22
+
23
+ # Copy an equation (faster than deepcopy)
24
+ function copyNode(tree::Node)::Node
25
+ if tree.degree == 0
26
+ return Node(tree.val)
27
+ elseif tree.degree == 1
28
+ return Node(tree.op, copyNode(tree.l))
29
+ else
30
+ return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
31
+ end
32
+ end
33
+
34
+ # Count the operators, constants, variables in an equation
35
+ function countNodes(tree::Node)::Integer
36
+ if tree.degree == 0
37
+ return 1
38
+ elseif tree.degree == 1
39
+ return 1 + countNodes(tree.l)
40
+ else
41
+ return 1 + countNodes(tree.l) + countNodes(tree.r)
42
+ end
43
+ end
44
+
45
+ # Count the max depth of a tree
46
+ function countDepth(tree::Node)::Integer
47
+ if tree.degree == 0
48
+ return 1
49
+ elseif tree.degree == 1
50
+ return 1 + countDepth(tree.l)
51
+ else
52
+ return 1 + max(countDepth(tree.l), countDepth(tree.r))
53
+ end
54
+ end
55
+
56
+ # Convert an equation to a string
57
+ function stringTree(tree::Node)::String
58
+ if tree.degree == 0
59
+ if tree.constant
60
+ return string(tree.val)
61
+ else
62
+ if useVarMap
63
+ return varMap[tree.val]
64
+ else
65
+ return "x$(tree.val - 1)"
66
+ end
67
+ end
68
+ elseif tree.degree == 1
69
+ return "$(unaops[tree.op])($(stringTree(tree.l)))"
70
+ else
71
+ return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
72
+ end
73
+ end
74
+
75
+ # Print an equation
76
+ function printTree(tree::Node)
77
+ println(stringTree(tree))
78
+ end
79
+
80
+ # Return a random node from the tree
81
+ function randomNode(tree::Node)::Node
82
+ if tree.degree == 0
83
+ return tree
84
+ end
85
+ a = countNodes(tree)
86
+ b = 0
87
+ c = 0
88
+ if tree.degree >= 1
89
+ b = countNodes(tree.l)
90
+ end
91
+ if tree.degree == 2
92
+ c = countNodes(tree.r)
93
+ end
94
+
95
+ i = rand(1:1+b+c)
96
+ if i <= b
97
+ return randomNode(tree.l)
98
+ elseif i == b + 1
99
+ return tree
100
+ end
101
+
102
+ return randomNode(tree.r)
103
+ end
104
+
105
+ # Count the number of unary operators in the equation
106
+ function countUnaryOperators(tree::Node)::Integer
107
+ if tree.degree == 0
108
+ return 0
109
+ elseif tree.degree == 1
110
+ return 1 + countUnaryOperators(tree.l)
111
+ else
112
+ return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
113
+ end
114
+ end
115
+
116
+ # Count the number of binary operators in the equation
117
+ function countBinaryOperators(tree::Node)::Integer
118
+ if tree.degree == 0
119
+ return 0
120
+ elseif tree.degree == 1
121
+ return 0 + countBinaryOperators(tree.l)
122
+ else
123
+ return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
124
+ end
125
+ end
126
+
127
+ # Count the number of operators in the equation
128
+ function countOperators(tree::Node)::Integer
129
+ return countUnaryOperators(tree) + countBinaryOperators(tree)
130
+ end
131
+
132
+
133
+ # Count the number of constants in an equation
134
+ function countConstants(tree::Node)::Integer
135
+ if tree.degree == 0
136
+ return convert(Integer, tree.constant)
137
+ elseif tree.degree == 1
138
+ return 0 + countConstants(tree.l)
139
+ else
140
+ return 0 + countConstants(tree.l) + countConstants(tree.r)
141
+ end
142
+ end
143
+
144
+ # Get all the constants from a tree
145
+ function getConstants(tree::Node)::Array{Float32, 1}
146
+ if tree.degree == 0
147
+ if tree.constant
148
+ return [tree.val]
149
+ else
150
+ return Float32[]
151
+ end
152
+ elseif tree.degree == 1
153
+ return getConstants(tree.l)
154
+ else
155
+ both = [getConstants(tree.l), getConstants(tree.r)]
156
+ return [constant for subtree in both for constant in subtree]
157
+ end
158
+ end
159
+
160
+ # Set all the constants inside a tree
161
+ function setConstants(tree::Node, constants::Array{Float32, 1})
162
+ if tree.degree == 0
163
+ if tree.constant
164
+ tree.val = constants[1]
165
+ end
166
+ elseif tree.degree == 1
167
+ setConstants(tree.l, constants)
168
+ else
169
+ numberLeft = countConstants(tree.l)
170
+ setConstants(tree.l, constants)
171
+ setConstants(tree.r, constants[numberLeft+1:end])
172
+ end
173
+ end
julia/EvaluateEquation.jl ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluate an equation over an array of datapoints
2
+ function evalTreeArray(tree::Node)::Union{Array{Float32, 1}, Nothing}
3
+ return evalTreeArray(tree, X)
4
+ end
5
+
6
+
7
+ # Evaluate an equation over an array of datapoints
8
+ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
9
+ clen = size(cX)[1]
10
+ if tree.degree == 0
11
+ if tree.constant
12
+ return fill(tree.val, clen)
13
+ else
14
+ return copy(cX[:, tree.val])
15
+ end
16
+ elseif tree.degree == 1
17
+ cumulator = evalTreeArray(tree.l, cX)
18
+ if cumulator === nothing
19
+ return nothing
20
+ end
21
+ op_idx = tree.op
22
+ UNAOP!(cumulator, op_idx, clen)
23
+ @inbounds for i=1:clen
24
+ if isinf(cumulator[i]) || isnan(cumulator[i])
25
+ return nothing
26
+ end
27
+ end
28
+ return cumulator
29
+ else
30
+ cumulator = evalTreeArray(tree.l, cX)
31
+ if cumulator === nothing
32
+ return nothing
33
+ end
34
+ array2 = evalTreeArray(tree.r, cX)
35
+ if array2 === nothing
36
+ return nothing
37
+ end
38
+ op_idx = tree.op
39
+ BINOP!(cumulator, array2, op_idx, clen)
40
+ @inbounds for i=1:clen
41
+ if isinf(cumulator[i]) || isnan(cumulator[i])
42
+ return nothing
43
+ end
44
+ end
45
+ return cumulator
46
+ end
47
+ end
julia/LossFunctions.jl ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import Random: randperm
2
+
3
+ # Sum of square error between two arrays
4
+ function SSE(x::Array{Float32}, y::Array{Float32})::Float32
5
+ diff = (x - y)
6
+ return sum(diff .* diff)
7
+ end
8
+ function SSE(x::Nothing, y::Array{Float32})::Float32
9
+ return 1f9
10
+ end
11
+
12
+ # Sum of square error between two arrays, with weights
13
+ function SSE(x::Array{Float32}, y::Array{Float32}, w::Array{Float32})::Float32
14
+ diff = (x - y)
15
+ return sum(diff .* diff .* w)
16
+ end
17
+ function SSE(x::Nothing, y::Array{Float32}, w::Array{Float32})::Float32
18
+ return Nothing
19
+ end
20
+
21
+ # Mean of square error between two arrays
22
+ function MSE(x::Nothing, y::Array{Float32})::Float32
23
+ return 1f9
24
+ end
25
+
26
+ # Mean of square error between two arrays
27
+ function MSE(x::Array{Float32}, y::Array{Float32})::Float32
28
+ return SSE(x, y)/size(x)[1]
29
+ end
30
+
31
+ # Mean of square error between two arrays
32
+ function MSE(x::Nothing, y::Array{Float32}, w::Array{Float32})::Float32
33
+ return 1f9
34
+ end
35
+
36
+ # Mean of square error between two arrays
37
+ function MSE(x::Array{Float32}, y::Array{Float32}, w::Array{Float32})::Float32
38
+ return SSE(x, y, w)/sum(w)
39
+ end
40
+
41
+ if weighted
42
+ const avgy = sum(y .* weights)/sum(weights)
43
+ const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy), weights)
44
+ else
45
+ const avgy = sum(y)/len
46
+ const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy))
47
+ end
48
+
49
+ # Score an equation
50
+ function scoreFunc(tree::Node)::Float32
51
+ prediction = evalTreeArray(tree)
52
+ if prediction === nothing
53
+ return 1f9
54
+ end
55
+ if weighted
56
+ mse = MSE(prediction, y, weights)
57
+ else
58
+ mse = MSE(prediction, y)
59
+ end
60
+ return mse / baselineMSE + countNodes(tree)*parsimony
61
+ end
62
+
63
+ # Score an equation with a small batch
64
+ function scoreFuncBatch(tree::Node)::Float32
65
+ # batchSize
66
+ batch_idx = randperm(len)[1:batchSize]
67
+ batch_X = X[batch_idx, :]
68
+ prediction = evalTreeArray(tree, batch_X)
69
+ if prediction === nothing
70
+ return 1f9
71
+ end
72
+ size_adjustment = 1f0
73
+ batch_y = y[batch_idx]
74
+ if weighted
75
+ batch_w = weights[batch_idx]
76
+ mse = MSE(prediction, batch_y, batch_w)
77
+ size_adjustment = 1f0 * len / batchSize
78
+ else
79
+ mse = MSE(prediction, batch_y)
80
+ end
81
+ return size_adjustment * mse / baselineMSE + countNodes(tree)*parsimony
82
+ end
julia/Mutate.jl ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Go through one mutation cycle
2
+ function iterate(member::PopMember, T::Float32, curmaxsize::Integer, frequencyComplexity::Array{Float32, 1})::PopMember
3
+ prev = member.tree
4
+ tree = prev
5
+ #TODO - reconsider this
6
+ if batching
7
+ beforeLoss = scoreFuncBatch(prev)
8
+ else
9
+ beforeLoss = member.score
10
+ end
11
+
12
+ mutationChoice = rand()
13
+ #More constants => more likely to do constant mutation
14
+ weightAdjustmentMutateConstant = min(8, countConstants(prev))/8.0
15
+ cur_weights = copy(mutationWeights) .* 1.0
16
+ cur_weights[1] *= weightAdjustmentMutateConstant
17
+ n = countNodes(prev)
18
+ depth = countDepth(prev)
19
+
20
+ # If equation too big, don't add new operators
21
+ if n >= curmaxsize || depth >= maxdepth
22
+ cur_weights[3] = 0.0
23
+ cur_weights[4] = 0.0
24
+ end
25
+ cur_weights /= sum(cur_weights)
26
+ cweights = cumsum(cur_weights)
27
+
28
+ successful_mutation = false
29
+ #TODO: Currently we dont take this \/ into account
30
+ is_success_always_possible = true
31
+ attempts = 0
32
+ max_attempts = 10
33
+
34
+ #############################################
35
+ # Mutations
36
+ #############################################
37
+ while (!successful_mutation) && attempts < max_attempts
38
+ tree = copyNode(prev)
39
+ successful_mutation = true
40
+ if mutationChoice < cweights[1]
41
+ tree = mutateConstant(tree, T)
42
+
43
+ is_success_always_possible = true
44
+ # Mutating a constant shouldn't invalidate an already-valid function
45
+
46
+ elseif mutationChoice < cweights[2]
47
+ tree = mutateOperator(tree)
48
+
49
+ is_success_always_possible = true
50
+ # Can always mutate to the same operator
51
+
52
+ elseif mutationChoice < cweights[3]
53
+ if rand() < 0.5
54
+ tree = appendRandomOp(tree)
55
+ else
56
+ tree = prependRandomOp(tree)
57
+ end
58
+ is_success_always_possible = false
59
+ # Can potentially have a situation without success
60
+ elseif mutationChoice < cweights[4]
61
+ tree = insertRandomOp(tree)
62
+ is_success_always_possible = false
63
+ elseif mutationChoice < cweights[5]
64
+ tree = deleteRandomOp(tree)
65
+ is_success_always_possible = true
66
+ elseif mutationChoice < cweights[6]
67
+ tree = simplifyTree(tree) # Sometimes we simplify tree
68
+ tree = combineOperators(tree) # See if repeated constants at outer levels
69
+ return PopMember(tree, beforeLoss)
70
+
71
+ is_success_always_possible = true
72
+ # Simplification shouldn't hurt complexity; unless some non-symmetric constraint
73
+ # to commutative operator...
74
+
75
+ elseif mutationChoice < cweights[7]
76
+ tree = genRandomTree(5) # Sometimes we generate a new tree completely tree
77
+
78
+ is_success_always_possible = true
79
+ else # no mutation applied
80
+ return PopMember(tree, beforeLoss)
81
+ end
82
+
83
+ # Check for illegal equations
84
+ for i=1:nbin
85
+ if successful_mutation && flagBinOperatorComplexity(tree, i)
86
+ successful_mutation = false
87
+ end
88
+ end
89
+ for i=1:nuna
90
+ if successful_mutation && flagUnaOperatorComplexity(tree, i)
91
+ successful_mutation = false
92
+ end
93
+ end
94
+
95
+ attempts += 1
96
+ end
97
+ #############################################
98
+
99
+ if !successful_mutation
100
+ return PopMember(copyNode(prev), beforeLoss)
101
+ end
102
+
103
+ if batching
104
+ afterLoss = scoreFuncBatch(tree)
105
+ else
106
+ afterLoss = scoreFunc(tree)
107
+ end
108
+
109
+ if annealing
110
+ delta = afterLoss - beforeLoss
111
+ probChange = exp(-delta/(T*alpha))
112
+ if useFrequency
113
+ oldSize = countNodes(prev)
114
+ newSize = countNodes(tree)
115
+ probChange *= frequencyComplexity[oldSize] / frequencyComplexity[newSize]
116
+ end
117
+
118
+ return_unaltered = (isnan(afterLoss) || probChange < rand())
119
+ if return_unaltered
120
+ return PopMember(copyNode(prev), beforeLoss)
121
+ end
122
+ end
123
+ return PopMember(tree, afterLoss)
124
+ end
julia/MutationFunctions.jl ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Randomly convert an operator into another one (binary->binary;
2
+ # unary->unary)
3
+ function mutateOperator(tree::Node)::Node
4
+ if countOperators(tree) == 0
5
+ return tree
6
+ end
7
+ node = randomNode(tree)
8
+ while node.degree == 0
9
+ node = randomNode(tree)
10
+ end
11
+ if node.degree == 1
12
+ node.op = rand(1:length(unaops))
13
+ else
14
+ node.op = rand(1:length(binops))
15
+ end
16
+ return tree
17
+ end
18
+
19
+ # Randomly perturb a constant
20
+ function mutateConstant(
21
+ tree::Node, T::Float32,
22
+ probNegate::Float32=0.01f0)::Node
23
+ # T is between 0 and 1.
24
+
25
+ if countConstants(tree) == 0
26
+ return tree
27
+ end
28
+ node = randomNode(tree)
29
+ while node.degree != 0 || node.constant == false
30
+ node = randomNode(tree)
31
+ end
32
+
33
+ bottom = 0.1f0
34
+ maxChange = perturbationFactor * T + 1.0f0 + bottom
35
+ factor = maxChange^Float32(rand())
36
+ makeConstBigger = rand() > 0.5
37
+
38
+ if makeConstBigger
39
+ node.val *= factor
40
+ else
41
+ node.val /= factor
42
+ end
43
+
44
+ if rand() > probNegate
45
+ node.val *= -1
46
+ end
47
+
48
+ return tree
49
+ end
50
+
51
+ # Add a random unary/binary operation to the end of a tree
52
+ function appendRandomOp(tree::Node)::Node
53
+ node = randomNode(tree)
54
+ while node.degree != 0
55
+ node = randomNode(tree)
56
+ end
57
+
58
+ choice = rand()
59
+ makeNewBinOp = choice < nbin/nops
60
+ if rand() > 0.5
61
+ left = Float32(randn())
62
+ else
63
+ left = rand(1:nvar)
64
+ end
65
+ if rand() > 0.5
66
+ right = Float32(randn())
67
+ else
68
+ right = rand(1:nvar)
69
+ end
70
+
71
+ if makeNewBinOp
72
+ newnode = Node(
73
+ rand(1:length(binops)),
74
+ left,
75
+ right
76
+ )
77
+ else
78
+ newnode = Node(
79
+ rand(1:length(unaops)),
80
+ left
81
+ )
82
+ end
83
+ node.l = newnode.l
84
+ node.r = newnode.r
85
+ node.op = newnode.op
86
+ node.degree = newnode.degree
87
+ node.val = newnode.val
88
+ node.constant = newnode.constant
89
+ return tree
90
+ end
91
+
92
+ # Insert random node
93
+ function insertRandomOp(tree::Node)::Node
94
+ node = randomNode(tree)
95
+ choice = rand()
96
+ makeNewBinOp = choice < nbin/nops
97
+ left = copyNode(node)
98
+
99
+ if makeNewBinOp
100
+ right = randomConstantNode()
101
+ newnode = Node(
102
+ rand(1:length(binops)),
103
+ left,
104
+ right
105
+ )
106
+ else
107
+ newnode = Node(
108
+ rand(1:length(unaops)),
109
+ left
110
+ )
111
+ end
112
+ node.l = newnode.l
113
+ node.r = newnode.r
114
+ node.op = newnode.op
115
+ node.degree = newnode.degree
116
+ node.val = newnode.val
117
+ node.constant = newnode.constant
118
+ return tree
119
+ end
120
+
121
+ # Add random node to the top of a tree
122
+ function prependRandomOp(tree::Node)::Node
123
+ node = tree
124
+ choice = rand()
125
+ makeNewBinOp = choice < nbin/nops
126
+ left = copyNode(tree)
127
+
128
+ if makeNewBinOp
129
+ right = randomConstantNode()
130
+ newnode = Node(
131
+ rand(1:length(binops)),
132
+ left,
133
+ right
134
+ )
135
+ else
136
+ newnode = Node(
137
+ rand(1:length(unaops)),
138
+ left
139
+ )
140
+ end
141
+ node.l = newnode.l
142
+ node.r = newnode.r
143
+ node.op = newnode.op
144
+ node.degree = newnode.degree
145
+ node.val = newnode.val
146
+ node.constant = newnode.constant
147
+ return node
148
+ end
149
+
150
+ function randomConstantNode()::Node
151
+ if rand() > 0.5
152
+ val = Float32(randn())
153
+ else
154
+ val = rand(1:nvar)
155
+ end
156
+ newnode = Node(val)
157
+ return newnode
158
+ end
159
+
160
+ # Return a random node from the tree with parent
161
+ function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
162
+ if tree.degree == 0
163
+ return tree, parent
164
+ end
165
+ a = countNodes(tree)
166
+ b = 0
167
+ c = 0
168
+ if tree.degree >= 1
169
+ b = countNodes(tree.l)
170
+ end
171
+ if tree.degree == 2
172
+ c = countNodes(tree.r)
173
+ end
174
+
175
+ i = rand(1:1+b+c)
176
+ if i <= b
177
+ return randomNodeAndParent(tree.l, tree)
178
+ elseif i == b + 1
179
+ return tree, parent
180
+ end
181
+
182
+ return randomNodeAndParent(tree.r, tree)
183
+ end
184
+
185
+ # Select a random node, and replace it an the subtree
186
+ # with a variable or constant
187
+ function deleteRandomOp(tree::Node)::Node
188
+ node, parent = randomNodeAndParent(tree, nothing)
189
+ isroot = (parent === nothing)
190
+
191
+ if node.degree == 0
192
+ # Replace with new constant
193
+ newnode = randomConstantNode()
194
+ node.l = newnode.l
195
+ node.r = newnode.r
196
+ node.op = newnode.op
197
+ node.degree = newnode.degree
198
+ node.val = newnode.val
199
+ node.constant = newnode.constant
200
+ elseif node.degree == 1
201
+ # Join one of the children with the parent
202
+ if isroot
203
+ return node.l
204
+ elseif parent.l == node
205
+ parent.l = node.l
206
+ else
207
+ parent.r = node.l
208
+ end
209
+ else
210
+ # Join one of the children with the parent
211
+ if rand() < 0.5
212
+ if isroot
213
+ return node.l
214
+ elseif parent.l == node
215
+ parent.l = node.l
216
+ else
217
+ parent.r = node.l
218
+ end
219
+ else
220
+ if isroot
221
+ return node.r
222
+ elseif parent.l == node
223
+ parent.l = node.r
224
+ else
225
+ parent.r = node.r
226
+ end
227
+ end
228
+ end
229
+ return tree
230
+ end
231
+
232
+ # Create a random equation by appending random operators
233
+ function genRandomTree(length::Integer)::Node
234
+ tree = Node(1.0f0)
235
+ for i=1:length
236
+ tree = appendRandomOp(tree)
237
+ end
238
+ return tree
239
+ end
julia/{operators.jl → Operators.jl} RENAMED
File without changes
julia/PopMember.jl ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # Define a member of population by equation, score, and age
2
+ mutable struct PopMember
3
+ tree::Node
4
+ score::Float32
5
+ birth::Integer
6
+
7
+ PopMember(t::Node) = new(t, scoreFunc(t), getTime())
8
+ PopMember(t::Node, score::Float32) = new(t, score, getTime())
9
+
10
+ end
julia/Population.jl ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # A list of members of the population, with easy constructors,
2
+ # which allow for random generation of new populations
3
+ mutable struct Population
4
+ members::Array{PopMember, 1}
5
+ n::Integer
6
+
7
+ Population(pop::Array{PopMember, 1}) = new(pop, size(pop)[1])
8
+ Population(npop::Integer) = new([PopMember(genRandomTree(3)) for i=1:npop], npop)
9
+ Population(npop::Integer, nlength::Integer) = new([PopMember(genRandomTree(nlength)) for i=1:npop], npop)
10
+
11
+ end
12
+
13
+ # Sample 10 random members of the population, and make a new one
14
+ function samplePop(pop::Population)::Population
15
+ idx = rand(1:pop.n, ns)
16
+ return Population(pop.members[idx])
17
+ end
18
+
19
+ # Sample the population, and get the best member from that sample
20
+ function bestOfSample(pop::Population)::PopMember
21
+ sample = samplePop(pop)
22
+ best_idx = argmin([sample.members[member].score for member=1:sample.n])
23
+ return sample.members[best_idx]
24
+ end
25
+
26
+ function finalizeScores(pop::Population)::Population
27
+ need_recalculate = batching
28
+ if need_recalculate
29
+ @inbounds @simd for member=1:pop.n
30
+ pop.members[member].score = scoreFunc(pop.members[member].tree)
31
+ end
32
+ end
33
+ return pop
34
+ end
35
+
36
+ # Return best 10 examples
37
+ function bestSubPop(pop::Population; topn::Integer=10)::Population
38
+ best_idx = sortperm([pop.members[member].score for member=1:pop.n])
39
+ return Population(pop.members[best_idx[1:topn]])
40
+ end
julia/ProgramConstants.jl ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+
2
+ const maxdegree = 2
3
+ const actualMaxsize = maxsize + maxdegree
4
+ const len = size(X)[1]
5
+
6
+ const nuna = size(unaops)[1]
7
+ const nbin = size(binops)[1]
8
+ const nops = nuna + nbin
9
+ const nvar = size(X)[2];
julia/RegularizedEvolution.jl ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import Random: shuffle!
2
+
3
+ # Pass through the population several times, replacing the oldest
4
+ # with the fittest of a small subsample
5
+ function regEvolCycle(pop::Population, T::Float32, curmaxsize::Integer,
6
+ frequencyComplexity::Array{Float32, 1})::Population
7
+ # Batch over each subsample. Can give 15% improvement in speed; probably moreso for large pops.
8
+ # but is ultimately a different algorithm than regularized evolution, and might not be
9
+ # as good.
10
+ if fast_cycle
11
+ shuffle!(pop.members)
12
+ n_evol_cycles = round(Integer, pop.n/ns)
13
+ babies = Array{PopMember}(undef, n_evol_cycles)
14
+
15
+ # Iterate each ns-member sub-sample
16
+ @inbounds Threads.@threads for i=1:n_evol_cycles
17
+ best_score = Inf32
18
+ best_idx = 1+(i-1)*ns
19
+ # Calculate best member of the subsample:
20
+ for sub_i=1+(i-1)*ns:i*ns
21
+ if pop.members[sub_i].score < best_score
22
+ best_score = pop.members[sub_i].score
23
+ best_idx = sub_i
24
+ end
25
+ end
26
+ allstar = pop.members[best_idx]
27
+ babies[i] = iterate(allstar, T, curmaxsize, frequencyComplexity)
28
+ end
29
+
30
+ # Replace the n_evol_cycles-oldest members of each population
31
+ @inbounds for i=1:n_evol_cycles
32
+ oldest = argmin([pop.members[member].birth for member=1:pop.n])
33
+ pop.members[oldest] = babies[i]
34
+ end
35
+ else
36
+ for i=1:round(Integer, pop.n/ns)
37
+ allstar = bestOfSample(pop)
38
+ baby = iterate(allstar, T, curmaxsize, frequencyComplexity)
39
+ #printTree(baby.tree)
40
+ oldest = argmin([pop.members[member].birth for member=1:pop.n])
41
+ pop.members[oldest] = baby
42
+ end
43
+ end
44
+
45
+ return pop
46
+ end
julia/SimplifyEquation.jl ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Simplify tree
2
+ function combineOperators(tree::Node)::Node
3
+ # NOTE: (const (+*-) const) already accounted for. Call simplifyTree before.
4
+ # ((const + var) + const) => (const + var)
5
+ # ((const * var) * const) => (const * var)
6
+ # ((const - var) - const) => (const - var)
7
+ # (want to add anything commutative!)
8
+ # TODO - need to combine plus/sub if they are both there.
9
+ if tree.degree == 0
10
+ return tree
11
+ elseif tree.degree == 1
12
+ tree.l = combineOperators(tree.l)
13
+ elseif tree.degree == 2
14
+ tree.l = combineOperators(tree.l)
15
+ tree.r = combineOperators(tree.r)
16
+ end
17
+
18
+ top_level_constant = tree.degree == 2 && (tree.l.constant || tree.r.constant)
19
+ if tree.degree == 2 && (binops[tree.op] === mult || binops[tree.op] === plus) && top_level_constant
20
+ op = tree.op
21
+ # Put the constant in r. Need to assume var in left for simplification assumption.
22
+ if tree.l.constant
23
+ tmp = tree.r
24
+ tree.r = tree.l
25
+ tree.l = tmp
26
+ end
27
+ topconstant = tree.r.val
28
+ # Simplify down first
29
+ below = tree.l
30
+ if below.degree == 2 && below.op == op
31
+ if below.l.constant
32
+ tree = below
33
+ tree.l.val = binops[op](tree.l.val, topconstant)
34
+ elseif below.r.constant
35
+ tree = below
36
+ tree.r.val = binops[op](tree.r.val, topconstant)
37
+ end
38
+ end
39
+ end
40
+
41
+ if tree.degree == 2 && binops[tree.op] === sub && top_level_constant
42
+ # Currently just simplifies subtraction. (can't assume both plus and sub are operators)
43
+ # Not commutative, so use different op.
44
+ if tree.l.constant
45
+ if tree.r.degree == 2 && binops[tree.r.op] === sub
46
+ if tree.r.l.constant
47
+ #(const - (const - var)) => (var - const)
48
+ l = tree.l
49
+ r = tree.r
50
+ simplified_const = -(l.val - r.l.val) #neg(sub(l.val, r.l.val))
51
+ tree.l = tree.r.r
52
+ tree.r = l
53
+ tree.r.val = simplified_const
54
+ elseif tree.r.r.constant
55
+ #(const - (var - const)) => (const - var)
56
+ l = tree.l
57
+ r = tree.r
58
+ simplified_const = l.val + r.r.val #plus(l.val, r.r.val)
59
+ tree.r = tree.r.l
60
+ tree.l.val = simplified_const
61
+ end
62
+ end
63
+ else #tree.r.constant is true
64
+ if tree.l.degree == 2 && binops[tree.l.op] === sub
65
+ if tree.l.l.constant
66
+ #((const - var) - const) => (const - var)
67
+ l = tree.l
68
+ r = tree.r
69
+ simplified_const = l.l.val - r.val#sub(l.l.val, r.val)
70
+ tree.r = tree.l.r
71
+ tree.l = r
72
+ tree.l.val = simplified_const
73
+ elseif tree.l.r.constant
74
+ #((var - const) - const) => (var - const)
75
+ l = tree.l
76
+ r = tree.r
77
+ simplified_const = r.val + l.r.val #plus(r.val, l.r.val)
78
+ tree.l = tree.l.l
79
+ tree.r.val = simplified_const
80
+ end
81
+ end
82
+ end
83
+ end
84
+ return tree
85
+ end
86
+
87
+ # Simplify tree
88
+ function simplifyTree(tree::Node)::Node
89
+ if tree.degree == 1
90
+ tree.l = simplifyTree(tree.l)
91
+ if tree.l.degree == 0 && tree.l.constant
92
+ return Node(unaops[tree.op](tree.l.val))
93
+ end
94
+ elseif tree.degree == 2
95
+ tree.l = simplifyTree(tree.l)
96
+ tree.r = simplifyTree(tree.r)
97
+ constantsBelow = (
98
+ tree.l.degree == 0 && tree.l.constant &&
99
+ tree.r.degree == 0 && tree.r.constant
100
+ )
101
+ if constantsBelow
102
+ return Node(binops[tree.op](tree.l.val, tree.r.val))
103
+ end
104
+ end
105
+ return tree
106
+ end
julia/SingleIteration.jl ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cycle through regularized evolution many times,
2
+ # printing the fittest equation every 10% through
3
+ function run(
4
+ pop::Population,
5
+ ncycles::Integer,
6
+ curmaxsize::Integer,
7
+ frequencyComplexity::Array{Float32, 1};
8
+ verbosity::Integer=0
9
+ )::Population
10
+
11
+ allT = LinRange(1.0f0, 0.0f0, ncycles)
12
+ for iT in 1:size(allT)[1]
13
+ if annealing
14
+ pop = regEvolCycle(pop, allT[iT], curmaxsize, frequencyComplexity)
15
+ else
16
+ pop = regEvolCycle(pop, 1.0f0, curmaxsize, frequencyComplexity)
17
+ end
18
+
19
+ if verbosity > 0 && (iT % verbosity == 0)
20
+ bestPops = bestSubPop(pop)
21
+ bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
22
+ bestCurScore = bestPops.members[bestCurScoreIdx].score
23
+ debug(verbosity, bestCurScore, " is the score for ", stringTree(bestPops.members[bestCurScoreIdx].tree))
24
+ end
25
+ end
26
+
27
+ return pop
28
+ end
julia/Utils.jl ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import Printf: @printf
2
+
3
+ function id(x::Float32)::Float32
4
+ x
5
+ end
6
+
7
+ function debug(verbosity, string...)
8
+ verbosity > 0 ? println(string...) : nothing
9
+ end
10
+
11
+ function getTime()::Integer
12
+ return round(Integer, 1e3*(time()-1.6e9))
13
+ end
14
+
15
+ # Check for errors before they happen
16
+ function testConfiguration()
17
+ test_input = LinRange(-100f0, 100f0, 99)
18
+
19
+ try
20
+ for left in test_input
21
+ for right in test_input
22
+ for binop in binops
23
+ test_output = binop.(left, right)
24
+ end
25
+ end
26
+ for unaop in unaops
27
+ test_output = unaop.(left)
28
+ end
29
+ end
30
+ catch error
31
+ @printf("\n\nYour configuration is invalid - one of your operators is not well-defined over the real line.\n\n\n")
32
+ throw(error)
33
+ end
34
+ end
julia/halloffame.jl ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # List of the best members seen all time
2
+ mutable struct HallOfFame
3
+ members::Array{PopMember, 1}
4
+ exists::Array{Bool, 1} #Whether it has been set
5
+
6
+ # Arranged by complexity - store one at each.
7
+ HallOfFame() = new([PopMember(Node(1f0), 1f9) for i=1:actualMaxsize], [false for i=1:actualMaxsize])
8
+ end
julia/sr.jl CHANGED
@@ -1,1057 +1,4 @@
1
- import Optim
2
  import Printf: @printf
3
- import Random: shuffle!, randperm
4
-
5
- const maxdegree = 2
6
- const actualMaxsize = maxsize + maxdegree
7
-
8
-
9
- # Sum of square error between two arrays
10
- function SSE(x::Array{Float32}, y::Array{Float32})::Float32
11
- diff = (x - y)
12
- return sum(diff .* diff)
13
- end
14
- function SSE(x::Nothing, y::Array{Float32})::Float32
15
- return 1f9
16
- end
17
-
18
- # Sum of square error between two arrays, with weights
19
- function SSE(x::Array{Float32}, y::Array{Float32}, w::Array{Float32})::Float32
20
- diff = (x - y)
21
- return sum(diff .* diff .* w)
22
- end
23
- function SSE(x::Nothing, y::Array{Float32}, w::Array{Float32})::Float32
24
- return Nothing
25
- end
26
-
27
- # Mean of square error between two arrays
28
- function MSE(x::Nothing, y::Array{Float32})::Float32
29
- return 1f9
30
- end
31
-
32
- # Mean of square error between two arrays
33
- function MSE(x::Array{Float32}, y::Array{Float32})::Float32
34
- return SSE(x, y)/size(x)[1]
35
- end
36
-
37
- # Mean of square error between two arrays
38
- function MSE(x::Nothing, y::Array{Float32}, w::Array{Float32})::Float32
39
- return 1f9
40
- end
41
-
42
- # Mean of square error between two arrays
43
- function MSE(x::Array{Float32}, y::Array{Float32}, w::Array{Float32})::Float32
44
- return SSE(x, y, w)/sum(w)
45
- end
46
-
47
- const len = size(X)[1]
48
-
49
- if weighted
50
- const avgy = sum(y .* weights)/sum(weights)
51
- const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy), weights)
52
- else
53
- const avgy = sum(y)/len
54
- const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy))
55
- end
56
-
57
-
58
- function id(x::Float32)::Float32
59
- x
60
- end
61
-
62
- const nuna = size(unaops)[1]
63
- const nbin = size(binops)[1]
64
- const nops = nuna + nbin
65
- const nvar = size(X)[2];
66
-
67
- function debug(verbosity, string...)
68
- verbosity > 0 ? println(string...) : nothing
69
- end
70
-
71
- function getTime()::Integer
72
- return round(Integer, 1e3*(time()-1.6e9))
73
- end
74
-
75
- # Define a serialization format for the symbolic equations:
76
- mutable struct Node
77
- #Holds operators, variables, constants in a tree
78
- degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
79
- val::Union{Float32, Integer} #Either const value, or enumerates variable
80
- constant::Bool #false if variable
81
- op::Integer #enumerates operator (separately for degree=1,2)
82
- l::Union{Node, Nothing}
83
- r::Union{Node, Nothing}
84
-
85
- Node(val::Float32) = new(0, val, true, 1, nothing, nothing)
86
- Node(val::Integer) = new(0, val, false, 1, nothing, nothing)
87
- Node(op::Integer, l::Node) = new(1, 0.0f0, false, op, l, nothing)
88
- Node(op::Integer, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
89
- Node(op::Integer, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
90
-
91
- #Allow to pass the leaf value without additional node call:
92
- Node(op::Integer, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
93
- Node(op::Integer, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
94
- Node(op::Integer, l::Union{Float32, Integer}, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, Node(l), Node(r))
95
- end
96
-
97
- # Copy an equation (faster than deepcopy)
98
- function copyNode(tree::Node)::Node
99
- if tree.degree == 0
100
- return Node(tree.val)
101
- elseif tree.degree == 1
102
- return Node(tree.op, copyNode(tree.l))
103
- else
104
- return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
105
- end
106
- end
107
-
108
- # Count the operators, constants, variables in an equation
109
- function countNodes(tree::Node)::Integer
110
- if tree.degree == 0
111
- return 1
112
- elseif tree.degree == 1
113
- return 1 + countNodes(tree.l)
114
- else
115
- return 1 + countNodes(tree.l) + countNodes(tree.r)
116
- end
117
- end
118
-
119
- # Count the max depth of a tree
120
- function countDepth(tree::Node)::Integer
121
- if tree.degree == 0
122
- return 1
123
- elseif tree.degree == 1
124
- return 1 + countDepth(tree.l)
125
- else
126
- return 1 + max(countDepth(tree.l), countDepth(tree.r))
127
- end
128
- end
129
-
130
- # Convert an equation to a string
131
- function stringTree(tree::Node)::String
132
- if tree.degree == 0
133
- if tree.constant
134
- return string(tree.val)
135
- else
136
- if useVarMap
137
- return varMap[tree.val]
138
- else
139
- return "x$(tree.val - 1)"
140
- end
141
- end
142
- elseif tree.degree == 1
143
- return "$(unaops[tree.op])($(stringTree(tree.l)))"
144
- else
145
- return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
146
- end
147
- end
148
-
149
- # Print an equation
150
- function printTree(tree::Node)
151
- println(stringTree(tree))
152
- end
153
-
154
- # Return a random node from the tree
155
- function randomNode(tree::Node)::Node
156
- if tree.degree == 0
157
- return tree
158
- end
159
- a = countNodes(tree)
160
- b = 0
161
- c = 0
162
- if tree.degree >= 1
163
- b = countNodes(tree.l)
164
- end
165
- if tree.degree == 2
166
- c = countNodes(tree.r)
167
- end
168
-
169
- i = rand(1:1+b+c)
170
- if i <= b
171
- return randomNode(tree.l)
172
- elseif i == b + 1
173
- return tree
174
- end
175
-
176
- return randomNode(tree.r)
177
- end
178
-
179
- # Count the number of unary operators in the equation
180
- function countUnaryOperators(tree::Node)::Integer
181
- if tree.degree == 0
182
- return 0
183
- elseif tree.degree == 1
184
- return 1 + countUnaryOperators(tree.l)
185
- else
186
- return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
187
- end
188
- end
189
-
190
- # Count the number of binary operators in the equation
191
- function countBinaryOperators(tree::Node)::Integer
192
- if tree.degree == 0
193
- return 0
194
- elseif tree.degree == 1
195
- return 0 + countBinaryOperators(tree.l)
196
- else
197
- return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
198
- end
199
- end
200
-
201
- # Count the number of operators in the equation
202
- function countOperators(tree::Node)::Integer
203
- return countUnaryOperators(tree) + countBinaryOperators(tree)
204
- end
205
-
206
- # Randomly convert an operator into another one (binary->binary;
207
- # unary->unary)
208
- function mutateOperator(tree::Node)::Node
209
- if countOperators(tree) == 0
210
- return tree
211
- end
212
- node = randomNode(tree)
213
- while node.degree == 0
214
- node = randomNode(tree)
215
- end
216
- if node.degree == 1
217
- node.op = rand(1:length(unaops))
218
- else
219
- node.op = rand(1:length(binops))
220
- end
221
- return tree
222
- end
223
-
224
- # Count the number of constants in an equation
225
- function countConstants(tree::Node)::Integer
226
- if tree.degree == 0
227
- return convert(Integer, tree.constant)
228
- elseif tree.degree == 1
229
- return 0 + countConstants(tree.l)
230
- else
231
- return 0 + countConstants(tree.l) + countConstants(tree.r)
232
- end
233
- end
234
-
235
- # Randomly perturb a constant
236
- function mutateConstant(
237
- tree::Node, T::Float32,
238
- probNegate::Float32=0.01f0)::Node
239
- # T is between 0 and 1.
240
-
241
- if countConstants(tree) == 0
242
- return tree
243
- end
244
- node = randomNode(tree)
245
- while node.degree != 0 || node.constant == false
246
- node = randomNode(tree)
247
- end
248
-
249
- bottom = 0.1f0
250
- maxChange = perturbationFactor * T + 1.0f0 + bottom
251
- factor = maxChange^Float32(rand())
252
- makeConstBigger = rand() > 0.5
253
-
254
- if makeConstBigger
255
- node.val *= factor
256
- else
257
- node.val /= factor
258
- end
259
-
260
- if rand() > probNegate
261
- node.val *= -1
262
- end
263
-
264
- return tree
265
- end
266
-
267
- # Evaluate an equation over an array of datapoints
268
- function evalTreeArray(tree::Node)::Union{Array{Float32, 1}, Nothing}
269
- return evalTreeArray(tree, X)
270
- end
271
-
272
-
273
- # Evaluate an equation over an array of datapoints
274
- function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32, 1}, Nothing}
275
- clen = size(cX)[1]
276
- if tree.degree == 0
277
- if tree.constant
278
- return fill(tree.val, clen)
279
- else
280
- return copy(cX[:, tree.val])
281
- end
282
- elseif tree.degree == 1
283
- cumulator = evalTreeArray(tree.l, cX)
284
- if cumulator === nothing
285
- return nothing
286
- end
287
- op_idx = tree.op
288
- UNAOP!(cumulator, op_idx, clen)
289
- @inbounds for i=1:clen
290
- if isinf(cumulator[i]) || isnan(cumulator[i])
291
- return nothing
292
- end
293
- end
294
- return cumulator
295
- else
296
- cumulator = evalTreeArray(tree.l, cX)
297
- if cumulator === nothing
298
- return nothing
299
- end
300
- array2 = evalTreeArray(tree.r, cX)
301
- if array2 === nothing
302
- return nothing
303
- end
304
- op_idx = tree.op
305
- BINOP!(cumulator, array2, op_idx, clen)
306
- @inbounds for i=1:clen
307
- if isinf(cumulator[i]) || isnan(cumulator[i])
308
- return nothing
309
- end
310
- end
311
- return cumulator
312
- end
313
- end
314
-
315
- # Score an equation
316
- function scoreFunc(tree::Node)::Float32
317
- prediction = evalTreeArray(tree)
318
- if prediction === nothing
319
- return 1f9
320
- end
321
- if weighted
322
- mse = MSE(prediction, y, weights)
323
- else
324
- mse = MSE(prediction, y)
325
- end
326
- return mse / baselineMSE + countNodes(tree)*parsimony
327
- end
328
-
329
- # Score an equation with a small batch
330
- function scoreFuncBatch(tree::Node)::Float32
331
- # batchSize
332
- batch_idx = randperm(len)[1:batchSize]
333
- batch_X = X[batch_idx, :]
334
- prediction = evalTreeArray(tree, batch_X)
335
- if prediction === nothing
336
- return 1f9
337
- end
338
- size_adjustment = 1f0
339
- batch_y = y[batch_idx]
340
- if weighted
341
- batch_w = weights[batch_idx]
342
- mse = MSE(prediction, batch_y, batch_w)
343
- size_adjustment = 1f0 * len / batchSize
344
- else
345
- mse = MSE(prediction, batch_y)
346
- end
347
- return size_adjustment * mse / baselineMSE + countNodes(tree)*parsimony
348
- end
349
-
350
- # Add a random unary/binary operation to the end of a tree
351
- function appendRandomOp(tree::Node)::Node
352
- node = randomNode(tree)
353
- while node.degree != 0
354
- node = randomNode(tree)
355
- end
356
-
357
- choice = rand()
358
- makeNewBinOp = choice < nbin/nops
359
- if rand() > 0.5
360
- left = Float32(randn())
361
- else
362
- left = rand(1:nvar)
363
- end
364
- if rand() > 0.5
365
- right = Float32(randn())
366
- else
367
- right = rand(1:nvar)
368
- end
369
-
370
- if makeNewBinOp
371
- newnode = Node(
372
- rand(1:length(binops)),
373
- left,
374
- right
375
- )
376
- else
377
- newnode = Node(
378
- rand(1:length(unaops)),
379
- left
380
- )
381
- end
382
- node.l = newnode.l
383
- node.r = newnode.r
384
- node.op = newnode.op
385
- node.degree = newnode.degree
386
- node.val = newnode.val
387
- node.constant = newnode.constant
388
- return tree
389
- end
390
-
391
- # Insert random node
392
- function insertRandomOp(tree::Node)::Node
393
- node = randomNode(tree)
394
- choice = rand()
395
- makeNewBinOp = choice < nbin/nops
396
- left = copyNode(node)
397
-
398
- if makeNewBinOp
399
- right = randomConstantNode()
400
- newnode = Node(
401
- rand(1:length(binops)),
402
- left,
403
- right
404
- )
405
- else
406
- newnode = Node(
407
- rand(1:length(unaops)),
408
- left
409
- )
410
- end
411
- node.l = newnode.l
412
- node.r = newnode.r
413
- node.op = newnode.op
414
- node.degree = newnode.degree
415
- node.val = newnode.val
416
- node.constant = newnode.constant
417
- return tree
418
- end
419
-
420
- # Add random node to the top of a tree
421
- function prependRandomOp(tree::Node)::Node
422
- node = tree
423
- choice = rand()
424
- makeNewBinOp = choice < nbin/nops
425
- left = copyNode(tree)
426
-
427
- if makeNewBinOp
428
- right = randomConstantNode()
429
- newnode = Node(
430
- rand(1:length(binops)),
431
- left,
432
- right
433
- )
434
- else
435
- newnode = Node(
436
- rand(1:length(unaops)),
437
- left
438
- )
439
- end
440
- node.l = newnode.l
441
- node.r = newnode.r
442
- node.op = newnode.op
443
- node.degree = newnode.degree
444
- node.val = newnode.val
445
- node.constant = newnode.constant
446
- return node
447
- end
448
-
449
- function randomConstantNode()::Node
450
- if rand() > 0.5
451
- val = Float32(randn())
452
- else
453
- val = rand(1:nvar)
454
- end
455
- newnode = Node(val)
456
- return newnode
457
- end
458
-
459
- # Return a random node from the tree with parent
460
- function randomNodeAndParent(tree::Node, parent::Union{Node, Nothing})::Tuple{Node, Union{Node, Nothing}}
461
- if tree.degree == 0
462
- return tree, parent
463
- end
464
- a = countNodes(tree)
465
- b = 0
466
- c = 0
467
- if tree.degree >= 1
468
- b = countNodes(tree.l)
469
- end
470
- if tree.degree == 2
471
- c = countNodes(tree.r)
472
- end
473
-
474
- i = rand(1:1+b+c)
475
- if i <= b
476
- return randomNodeAndParent(tree.l, tree)
477
- elseif i == b + 1
478
- return tree, parent
479
- end
480
-
481
- return randomNodeAndParent(tree.r, tree)
482
- end
483
-
484
- # Select a random node, and replace it an the subtree
485
- # with a variable or constant
486
- function deleteRandomOp(tree::Node)::Node
487
- node, parent = randomNodeAndParent(tree, nothing)
488
- isroot = (parent === nothing)
489
-
490
- if node.degree == 0
491
- # Replace with new constant
492
- newnode = randomConstantNode()
493
- node.l = newnode.l
494
- node.r = newnode.r
495
- node.op = newnode.op
496
- node.degree = newnode.degree
497
- node.val = newnode.val
498
- node.constant = newnode.constant
499
- elseif node.degree == 1
500
- # Join one of the children with the parent
501
- if isroot
502
- return node.l
503
- elseif parent.l == node
504
- parent.l = node.l
505
- else
506
- parent.r = node.l
507
- end
508
- else
509
- # Join one of the children with the parent
510
- if rand() < 0.5
511
- if isroot
512
- return node.l
513
- elseif parent.l == node
514
- parent.l = node.l
515
- else
516
- parent.r = node.l
517
- end
518
- else
519
- if isroot
520
- return node.r
521
- elseif parent.l == node
522
- parent.l = node.r
523
- else
524
- parent.r = node.r
525
- end
526
- end
527
- end
528
- return tree
529
- end
530
-
531
- # Simplify tree
532
- function combineOperators(tree::Node)::Node
533
- # NOTE: (const (+*-) const) already accounted for. Call simplifyTree before.
534
- # ((const + var) + const) => (const + var)
535
- # ((const * var) * const) => (const * var)
536
- # ((const - var) - const) => (const - var)
537
- # (want to add anything commutative!)
538
- # TODO - need to combine plus/sub if they are both there.
539
- if tree.degree == 0
540
- return tree
541
- elseif tree.degree == 1
542
- tree.l = combineOperators(tree.l)
543
- elseif tree.degree == 2
544
- tree.l = combineOperators(tree.l)
545
- tree.r = combineOperators(tree.r)
546
- end
547
-
548
- top_level_constant = tree.degree == 2 && (tree.l.constant || tree.r.constant)
549
- if tree.degree == 2 && (binops[tree.op] === mult || binops[tree.op] === plus) && top_level_constant
550
- op = tree.op
551
- # Put the constant in r. Need to assume var in left for simplification assumption.
552
- if tree.l.constant
553
- tmp = tree.r
554
- tree.r = tree.l
555
- tree.l = tmp
556
- end
557
- topconstant = tree.r.val
558
- # Simplify down first
559
- below = tree.l
560
- if below.degree == 2 && below.op == op
561
- if below.l.constant
562
- tree = below
563
- tree.l.val = binops[op](tree.l.val, topconstant)
564
- elseif below.r.constant
565
- tree = below
566
- tree.r.val = binops[op](tree.r.val, topconstant)
567
- end
568
- end
569
- end
570
-
571
- if tree.degree == 2 && binops[tree.op] === sub && top_level_constant
572
- # Currently just simplifies subtraction. (can't assume both plus and sub are operators)
573
- # Not commutative, so use different op.
574
- if tree.l.constant
575
- if tree.r.degree == 2 && binops[tree.r.op] === sub
576
- if tree.r.l.constant
577
- #(const - (const - var)) => (var - const)
578
- l = tree.l
579
- r = tree.r
580
- simplified_const = -(l.val - r.l.val) #neg(sub(l.val, r.l.val))
581
- tree.l = tree.r.r
582
- tree.r = l
583
- tree.r.val = simplified_const
584
- elseif tree.r.r.constant
585
- #(const - (var - const)) => (const - var)
586
- l = tree.l
587
- r = tree.r
588
- simplified_const = l.val + r.r.val #plus(l.val, r.r.val)
589
- tree.r = tree.r.l
590
- tree.l.val = simplified_const
591
- end
592
- end
593
- else #tree.r.constant is true
594
- if tree.l.degree == 2 && binops[tree.l.op] === sub
595
- if tree.l.l.constant
596
- #((const - var) - const) => (const - var)
597
- l = tree.l
598
- r = tree.r
599
- simplified_const = l.l.val - r.val#sub(l.l.val, r.val)
600
- tree.r = tree.l.r
601
- tree.l = r
602
- tree.l.val = simplified_const
603
- elseif tree.l.r.constant
604
- #((var - const) - const) => (var - const)
605
- l = tree.l
606
- r = tree.r
607
- simplified_const = r.val + l.r.val #plus(r.val, l.r.val)
608
- tree.l = tree.l.l
609
- tree.r.val = simplified_const
610
- end
611
- end
612
- end
613
- end
614
- return tree
615
- end
616
-
617
- # Simplify tree
618
- function simplifyTree(tree::Node)::Node
619
- if tree.degree == 1
620
- tree.l = simplifyTree(tree.l)
621
- if tree.l.degree == 0 && tree.l.constant
622
- return Node(unaops[tree.op](tree.l.val))
623
- end
624
- elseif tree.degree == 2
625
- tree.l = simplifyTree(tree.l)
626
- tree.r = simplifyTree(tree.r)
627
- constantsBelow = (
628
- tree.l.degree == 0 && tree.l.constant &&
629
- tree.r.degree == 0 && tree.r.constant
630
- )
631
- if constantsBelow
632
- return Node(binops[tree.op](tree.l.val, tree.r.val))
633
- end
634
- end
635
- return tree
636
- end
637
-
638
- # Define a member of population by equation, score, and age
639
- mutable struct PopMember
640
- tree::Node
641
- score::Float32
642
- birth::Integer
643
-
644
- PopMember(t::Node) = new(t, scoreFunc(t), getTime())
645
- PopMember(t::Node, score::Float32) = new(t, score, getTime())
646
-
647
- end
648
-
649
- # Check if any binary operator are overly complex
650
- function flagBinOperatorComplexity(tree::Node, op::Int)::Bool
651
- if tree.degree == 0
652
- return false
653
- elseif tree.degree == 1
654
- return flagBinOperatorComplexity(tree.l, op)
655
- else
656
- if tree.op == op
657
- overly_complex = (
658
- ((bin_constraints[op][1] > -1) &&
659
- (countNodes(tree.l) > bin_constraints[op][1]))
660
- ||
661
- ((bin_constraints[op][2] > -1) &&
662
- (countNodes(tree.r) > bin_constraints[op][2]))
663
- )
664
- if overly_complex
665
- return true
666
- end
667
- end
668
- return (flagBinOperatorComplexity(tree.l, op) || flagBinOperatorComplexity(tree.r, op))
669
- end
670
- end
671
-
672
- # Check if any unary operators are overly complex
673
- function flagUnaOperatorComplexity(tree::Node, op::Int)::Bool
674
- if tree.degree == 0
675
- return false
676
- elseif tree.degree == 1
677
- if tree.op == op
678
- overly_complex = (
679
- (una_constraints[op] > -1) &&
680
- (countNodes(tree.l) > una_constraints[op])
681
- )
682
- if overly_complex
683
- return true
684
- end
685
- end
686
- return flagUnaOperatorComplexity(tree.l, op)
687
- else
688
- return (flagUnaOperatorComplexity(tree.l, op) || flagUnaOperatorComplexity(tree.r, op))
689
- end
690
- end
691
-
692
- # Go through one simulated annealing mutation cycle
693
- # exp(-delta/T) defines probability of accepting a change
694
- function iterate(member::PopMember, T::Float32, curmaxsize::Integer, frequencyComplexity::Array{Float32, 1})::PopMember
695
- prev = member.tree
696
- tree = prev
697
- #TODO - reconsider this
698
- if batching
699
- beforeLoss = scoreFuncBatch(prev)
700
- else
701
- beforeLoss = member.score
702
- end
703
-
704
- mutationChoice = rand()
705
- #More constants => more likely to do constant mutation
706
- weightAdjustmentMutateConstant = min(8, countConstants(prev))/8.0
707
- cur_weights = copy(mutationWeights) .* 1.0
708
- cur_weights[1] *= weightAdjustmentMutateConstant
709
- n = countNodes(prev)
710
- depth = countDepth(prev)
711
-
712
- # If equation too big, don't add new operators
713
- if n >= curmaxsize || depth >= maxdepth
714
- cur_weights[3] = 0.0
715
- cur_weights[4] = 0.0
716
- end
717
- cur_weights /= sum(cur_weights)
718
- cweights = cumsum(cur_weights)
719
-
720
- successful_mutation = false
721
- #TODO: Currently we dont take this \/ into account
722
- is_success_always_possible = true
723
- attempts = 0
724
- max_attempts = 10
725
-
726
- #############################################
727
- # Mutations
728
- #############################################
729
- while (!successful_mutation) && attempts < max_attempts
730
- tree = copyNode(prev)
731
- successful_mutation = true
732
- if mutationChoice < cweights[1]
733
- tree = mutateConstant(tree, T)
734
-
735
- is_success_always_possible = true
736
- # Mutating a constant shouldn't invalidate an already-valid function
737
-
738
- elseif mutationChoice < cweights[2]
739
- tree = mutateOperator(tree)
740
-
741
- is_success_always_possible = true
742
- # Can always mutate to the same operator
743
-
744
- elseif mutationChoice < cweights[3]
745
- if rand() < 0.5
746
- tree = appendRandomOp(tree)
747
- else
748
- tree = prependRandomOp(tree)
749
- end
750
- is_success_always_possible = false
751
- # Can potentially have a situation without success
752
- elseif mutationChoice < cweights[4]
753
- tree = insertRandomOp(tree)
754
- is_success_always_possible = false
755
- elseif mutationChoice < cweights[5]
756
- tree = deleteRandomOp(tree)
757
- is_success_always_possible = true
758
- elseif mutationChoice < cweights[6]
759
- tree = simplifyTree(tree) # Sometimes we simplify tree
760
- tree = combineOperators(tree) # See if repeated constants at outer levels
761
- return PopMember(tree, beforeLoss)
762
-
763
- is_success_always_possible = true
764
- # Simplification shouldn't hurt complexity; unless some non-symmetric constraint
765
- # to commutative operator...
766
-
767
- elseif mutationChoice < cweights[7]
768
- tree = genRandomTree(5) # Sometimes we generate a new tree completely tree
769
-
770
- is_success_always_possible = true
771
- else # no mutation applied
772
- return PopMember(tree, beforeLoss)
773
- end
774
-
775
- # Check for illegal equations
776
- for i=1:nbin
777
- if successful_mutation && flagBinOperatorComplexity(tree, i)
778
- successful_mutation = false
779
- end
780
- end
781
- for i=1:nuna
782
- if successful_mutation && flagUnaOperatorComplexity(tree, i)
783
- successful_mutation = false
784
- end
785
- end
786
-
787
- attempts += 1
788
- end
789
- #############################################
790
-
791
- if !successful_mutation
792
- return PopMember(copyNode(prev), beforeLoss)
793
- end
794
-
795
- if batching
796
- afterLoss = scoreFuncBatch(tree)
797
- else
798
- afterLoss = scoreFunc(tree)
799
- end
800
-
801
- if annealing
802
- delta = afterLoss - beforeLoss
803
- probChange = exp(-delta/(T*alpha))
804
- if useFrequency
805
- oldSize = countNodes(prev)
806
- newSize = countNodes(tree)
807
- probChange *= frequencyComplexity[oldSize] / frequencyComplexity[newSize]
808
- end
809
-
810
- return_unaltered = (isnan(afterLoss) || probChange < rand())
811
- if return_unaltered
812
- return PopMember(copyNode(prev), beforeLoss)
813
- end
814
- end
815
- return PopMember(tree, afterLoss)
816
- end
817
-
818
- # Create a random equation by appending random operators
819
- function genRandomTree(length::Integer)::Node
820
- tree = Node(1.0f0)
821
- for i=1:length
822
- tree = appendRandomOp(tree)
823
- end
824
- return tree
825
- end
826
-
827
-
828
- # A list of members of the population, with easy constructors,
829
- # which allow for random generation of new populations
830
- mutable struct Population
831
- members::Array{PopMember, 1}
832
- n::Integer
833
-
834
- Population(pop::Array{PopMember, 1}) = new(pop, size(pop)[1])
835
- Population(npop::Integer) = new([PopMember(genRandomTree(3)) for i=1:npop], npop)
836
- Population(npop::Integer, nlength::Integer) = new([PopMember(genRandomTree(nlength)) for i=1:npop], npop)
837
-
838
- end
839
-
840
- # Sample 10 random members of the population, and make a new one
841
- function samplePop(pop::Population)::Population
842
- idx = rand(1:pop.n, ns)
843
- return Population(pop.members[idx])
844
- end
845
-
846
- # Sample the population, and get the best member from that sample
847
- function bestOfSample(pop::Population)::PopMember
848
- sample = samplePop(pop)
849
- best_idx = argmin([sample.members[member].score for member=1:sample.n])
850
- return sample.members[best_idx]
851
- end
852
-
853
- function finalizeScores(pop::Population)::Population
854
- need_recalculate = batching
855
- if need_recalculate
856
- @inbounds @simd for member=1:pop.n
857
- pop.members[member].score = scoreFunc(pop.members[member].tree)
858
- end
859
- end
860
- return pop
861
- end
862
-
863
- # Return best 10 examples
864
- function bestSubPop(pop::Population; topn::Integer=10)::Population
865
- best_idx = sortperm([pop.members[member].score for member=1:pop.n])
866
- return Population(pop.members[best_idx[1:topn]])
867
- end
868
-
869
- # Pass through the population several times, replacing the oldest
870
- # with the fittest of a small subsample
871
- function regEvolCycle(pop::Population, T::Float32, curmaxsize::Integer,
872
- frequencyComplexity::Array{Float32, 1})::Population
873
- # Batch over each subsample. Can give 15% improvement in speed; probably moreso for large pops.
874
- # but is ultimately a different algorithm than regularized evolution, and might not be
875
- # as good.
876
- if fast_cycle
877
- shuffle!(pop.members)
878
- n_evol_cycles = round(Integer, pop.n/ns)
879
- babies = Array{PopMember}(undef, n_evol_cycles)
880
-
881
- # Iterate each ns-member sub-sample
882
- @inbounds Threads.@threads for i=1:n_evol_cycles
883
- best_score = Inf32
884
- best_idx = 1+(i-1)*ns
885
- # Calculate best member of the subsample:
886
- for sub_i=1+(i-1)*ns:i*ns
887
- if pop.members[sub_i].score < best_score
888
- best_score = pop.members[sub_i].score
889
- best_idx = sub_i
890
- end
891
- end
892
- allstar = pop.members[best_idx]
893
- babies[i] = iterate(allstar, T, curmaxsize, frequencyComplexity)
894
- end
895
-
896
- # Replace the n_evol_cycles-oldest members of each population
897
- @inbounds for i=1:n_evol_cycles
898
- oldest = argmin([pop.members[member].birth for member=1:pop.n])
899
- pop.members[oldest] = babies[i]
900
- end
901
- else
902
- for i=1:round(Integer, pop.n/ns)
903
- allstar = bestOfSample(pop)
904
- baby = iterate(allstar, T, curmaxsize, frequencyComplexity)
905
- #printTree(baby.tree)
906
- oldest = argmin([pop.members[member].birth for member=1:pop.n])
907
- pop.members[oldest] = baby
908
- end
909
- end
910
-
911
- return pop
912
- end
913
-
914
- # Cycle through regularized evolution many times,
915
- # printing the fittest equation every 10% through
916
- function run(
917
- pop::Population,
918
- ncycles::Integer,
919
- curmaxsize::Integer,
920
- frequencyComplexity::Array{Float32, 1};
921
- verbosity::Integer=0
922
- )::Population
923
-
924
- allT = LinRange(1.0f0, 0.0f0, ncycles)
925
- for iT in 1:size(allT)[1]
926
- if annealing
927
- pop = regEvolCycle(pop, allT[iT], curmaxsize, frequencyComplexity)
928
- else
929
- pop = regEvolCycle(pop, 1.0f0, curmaxsize, frequencyComplexity)
930
- end
931
-
932
- if verbosity > 0 && (iT % verbosity == 0)
933
- bestPops = bestSubPop(pop)
934
- bestCurScoreIdx = argmin([bestPops.members[member].score for member=1:bestPops.n])
935
- bestCurScore = bestPops.members[bestCurScoreIdx].score
936
- debug(verbosity, bestCurScore, " is the score for ", stringTree(bestPops.members[bestCurScoreIdx].tree))
937
- end
938
- end
939
-
940
- return pop
941
- end
942
-
943
- # Get all the constants from a tree
944
- function getConstants(tree::Node)::Array{Float32, 1}
945
- if tree.degree == 0
946
- if tree.constant
947
- return [tree.val]
948
- else
949
- return Float32[]
950
- end
951
- elseif tree.degree == 1
952
- return getConstants(tree.l)
953
- else
954
- both = [getConstants(tree.l), getConstants(tree.r)]
955
- return [constant for subtree in both for constant in subtree]
956
- end
957
- end
958
-
959
- # Set all the constants inside a tree
960
- function setConstants(tree::Node, constants::Array{Float32, 1})
961
- if tree.degree == 0
962
- if tree.constant
963
- tree.val = constants[1]
964
- end
965
- elseif tree.degree == 1
966
- setConstants(tree.l, constants)
967
- else
968
- numberLeft = countConstants(tree.l)
969
- setConstants(tree.l, constants)
970
- setConstants(tree.r, constants[numberLeft+1:end])
971
- end
972
- end
973
-
974
-
975
- # Proxy function for optimization
976
- function optFunc(x::Array{Float32, 1}, tree::Node)::Float32
977
- setConstants(tree, x)
978
- return scoreFunc(tree)
979
- end
980
-
981
- # Use Nelder-Mead to optimize the constants in an equation
982
- function optimizeConstants(member::PopMember)::PopMember
983
- nconst = countConstants(member.tree)
984
- if nconst == 0
985
- return member
986
- end
987
- x0 = getConstants(member.tree)
988
- f(x::Array{Float32,1})::Float32 = optFunc(x, member.tree)
989
- if size(x0)[1] == 1
990
- algorithm = Optim.Newton
991
- else
992
- algorithm = Optim.NelderMead
993
- end
994
-
995
- try
996
- result = Optim.optimize(f, x0, algorithm(), Optim.Options(iterations=100))
997
- # Try other initial conditions:
998
- for i=1:nrestarts
999
- tmpresult = Optim.optimize(f, x0 .* (1f0 .+ 5f-1*randn(Float32, size(x0)[1])), algorithm(), Optim.Options(iterations=100))
1000
- if tmpresult.minimum < result.minimum
1001
- result = tmpresult
1002
- end
1003
- end
1004
-
1005
- if Optim.converged(result)
1006
- setConstants(member.tree, result.minimizer)
1007
- member.score = convert(Float32, result.minimum)
1008
- member.birth = getTime()
1009
- else
1010
- setConstants(member.tree, x0)
1011
- end
1012
- catch error
1013
- # Fine if optimization encountered domain error, just return x0
1014
- if isa(error, AssertionError)
1015
- setConstants(member.tree, x0)
1016
- else
1017
- throw(error)
1018
- end
1019
- end
1020
- return member
1021
- end
1022
-
1023
-
1024
- # List of the best members seen all time
1025
- mutable struct HallOfFame
1026
- members::Array{PopMember, 1}
1027
- exists::Array{Bool, 1} #Whether it has been set
1028
-
1029
- # Arranged by complexity - store one at each.
1030
- HallOfFame() = new([PopMember(Node(1f0), 1f9) for i=1:actualMaxsize], [false for i=1:actualMaxsize])
1031
- end
1032
-
1033
-
1034
- # Check for errors before they happen
1035
- function testConfiguration()
1036
- test_input = LinRange(-100f0, 100f0, 99)
1037
-
1038
- try
1039
- for left in test_input
1040
- for right in test_input
1041
- for binop in binops
1042
- test_output = binop.(left, right)
1043
- end
1044
- end
1045
- for unaop in unaops
1046
- test_output = unaop.(left)
1047
- end
1048
- end
1049
- catch error
1050
- @printf("\n\nYour configuration is invalid - one of your operators is not well-defined over the real line.\n\n\n")
1051
- throw(error)
1052
- end
1053
- end
1054
-
1055
 
1056
  function fullRun(niterations::Integer;
1057
  npop::Integer=300,
 
 
1
  import Printf: @printf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  function fullRun(niterations::Integer;
4
  npop::Integer=300,
julia/truth.jl ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *** Custom Functions
2
+ ##################################################################################################################################
3
+ # *** Will somewhere need to define a list TRUTHS of all valid auxliary truths
4
+ struct Transformation
5
+ type::Integer # 1 is symmetry, 2 is zero, 3 is equality
6
+ params::Array{Int32}
7
+ Transformation(type::Integer, params::Array{Int32}) = new(type, params)
8
+ Transformation(type::Integer, params::Array{Int64}) = new(type, params)
9
+
10
+ end
11
+ struct Truth
12
+ transformation::Transformation
13
+ weights::Array{Float32}
14
+ Truth(transformation::Transformation, weights::Array{Float32}) = new(transformation, weights)
15
+ Truth(type::Int64, params::Array{Int64}, weights::Array{Float32}) = new(Transformation(type, params), weights)
16
+ Truth(transformation::Transformation, weights::Array{Float64}) = new(transformation, weights)
17
+ Truth(type::Int64, params::Array{Int64}, weights::Array{Float64}) = new(Transformation(type, params), weights)
18
+ end
19
+ # Returns a linear combination when given X of shape nxd, y of shape nx1 is f(x) and w of shape d+2x1, result is shape nx1
20
+ function LinearPrediction(cX::Array{Float32}, cy::Array{Float32}, w::Array{Float32})::Array{Float32}
21
+ preds = 0
22
+ for i in 1:ndims(cX)
23
+ preds = preds .+ cX[:,i].*w[i]
24
+ end
25
+ preds = preds .+ cy.*w[ndims(cX)+1]
26
+ return preds .+ w[ndims(cX)+2]
27
+ end
28
+
29
+ # Returns a copy of the data with the two specified columns swapped
30
+ function swapColumns(cX::Array{Float32, 2}, a::Integer, b::Integer)::Array{Float32, 2}
31
+ X1 = copy(cX)
32
+ X1[:, a] = cX[:, b]
33
+ X1[:, b] = cX[:, a]
34
+ return X1
35
+ end
36
+
37
+ # Returns a copy of the data with the specified integers in the list set to value given
38
+ function setVal(cX::Array{Float32, 2}, a::Array{Int32, 1}, val::Float32)::Array{Float32, 2}
39
+ X1 = copy(cX)
40
+ for i in 1:size(a)[1]
41
+ X1[:, a[i]] = fill!(cX[:, a[i]], val)
42
+ end
43
+ return X1
44
+ end
45
+
46
+ # Returns a copy of the data with the specified integer indices in the list set to the first item of that list
47
+ function setEq(cX::Array{Float32, 2}, a::Array{Int32, 1})::Array{Float32, 2}
48
+ X1 = copy(cX)
49
+ val = X1[:, a[1]]
50
+ for i in 1:size(a)[1]
51
+ X1[:, a[i]] = val
52
+ end
53
+ return X1
54
+ end
55
+
56
+ # Takes in a dataset and returns the transformed version of it as per the specified type and parameters
57
+ function transform(cX::Array{Float32, 2}, transformation::Transformation)::Array{Float32, 2}
58
+ if transformation.type==1 # then symmetry
59
+ a = transformation.params[1]
60
+ b = transformation.params[2]
61
+ return swapColumns(cX, a, b)
62
+ elseif transformation.type==2 # then zero condition
63
+ return setVal(cX, transformation.params, Float32(0))
64
+ elseif transformation.type == 3 # then equality condition
65
+ return setEq(cX, transformation.params)
66
+ else # Then error return X
67
+ return cX
68
+ end
69
+ end
70
+ function transform(cX::Array{Float32, 2}, truth::Truth)::Array{Float32, 2}
71
+ return transform(cX, truth.transformation)
72
+ end
73
+
74
+ # Takes in X that has been transformed and returns what the Truth projects the target values should be
75
+ function truthPrediction(X_transformed::Array{Float32, 2}, cy::Array{Float32}, truth::Truth)::Array{Float32}
76
+ return LinearPrediction(X_transformed, cy, truth.weights)
77
+ end
julia/truthPops.jl ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Returns the MSE between the predictions and the truth provided targets for the given dataset
2
+ function truthScore(member::PopMember, cX::Array{Float32, 2}, cy::Array{Float32}, truth::Truth)::Float32
3
+ transformed = transform(cX, truth)
4
+ targets = truthPrediction(transformed, cy, truth)
5
+ preds = evalTreeArray(member.tree, transformed)
6
+ return MSE(preds, targets)
7
+ end
8
+
9
+ # Assumes a dataset X, y for a given truth
10
+ function truthScore(member::PopMember, truth::Truth)::Float32
11
+ return truthScore(member, X, y, truth)
12
+ end
13
+
14
+ # Assumes a list of Truths TRUTHS is defined. Performs the truthScore function for each of them and returns the average
15
+ function truthScore(member::PopMember, cX::Array{Float32, 2}, cy::Array{Float32})::Float32
16
+ s = 0
17
+ for truth in TRUTHS
18
+ s += (truthScore(member, cX, cy, truth))/size(TRUTHS)[1]
19
+ end
20
+ return s
21
+ end
22
+
23
+ # Assumes list of Truths TRUTHS and dataset X, y are defined
24
+ function truthScore(member::PopMember)::Float32
25
+ return truthScore(member, X, y)
26
+ end
27
+ # Returns the MSE between the predictions and the truth provided targets for the given dataset
28
+ function truthScore(tree::Node, cX::Array{Float32, 2}, cy::Array{Float32}, truth::Truth)::Float32
29
+ transformed = transform(cX, truth)
30
+ targets = truthPrediction(transformed, cy, truth)
31
+ preds = evalTreeArray(tree, transformed)
32
+ return MSE(preds, targets)
33
+ end
34
+
35
+ # Assumes a dataset X, y for a given truth
36
+ function truthScore(tree::Node, truth::Truth)::Float32
37
+ return truthScore(tree, X, y, truth)
38
+ end
39
+
40
+ # Assumes a list of Truths TRUTHS is defined. Performs the truthScore function for each of them and returns the average
41
+ function truthScore(tree::Node, cX::Array{Float32, 2}, cy::Array{Float32})::Float32
42
+ s = 0
43
+ for truth in TRUTHS
44
+ s += (truthScore(tree, cX, cy, truth))/size(TRUTHS)[1]
45
+ end
46
+ return s
47
+ end
48
+
49
+ # Assumes list of Truths TRUTHS and dataset X, y are defined
50
+ function truthScore(tree::Node)::Float32
51
+ return truthScore(tree, X, y)
52
+ end
53
+
54
+ # Returns true iff Truth Score is below a given threshold i.e truth is satisfied
55
+ function testTruth(member::PopMember, truth::Truth, threshold::Float32=Float32(1.0e-8))::Bool
56
+ truthError = truthScore(member, truth)
57
+ #print(stringTree(member.tree), "\n")
58
+ #print(truth, ": ")
59
+ #print(truthError, "\n")
60
+ if truthError > threshold
61
+ #print("Returns False \n ----\n")
62
+ return false
63
+ else
64
+ #print("Returns True \n ----\n")
65
+ return true
66
+ end
67
+ end
68
+
69
+ # Returns a list of violating functions from assumed list TRUTHS
70
+ function violatingTruths(member::PopMember)::Array{Truth}
71
+ return violatingTruths(member.tree)
72
+ end
73
+
74
+ # Returns true iff Truth Score is below a given threshold i.e truth is satisfied
75
+ function testTruth(tree::Node, truth::Truth, threshold::Float32=Float32(1.0e-3))::Bool
76
+ truthError = truthScore(tree, truth)
77
+ if truthError > threshold
78
+ return false
79
+ else
80
+ return true
81
+ end
82
+ end
83
+
84
+ # Returns a list of violating functions from assumed list TRUTHS
85
+ function violatingTruths(tree::Node)::Array{Truth}
86
+ toReturn = []
87
+ #print("\n Checking Equation ", stringTree(tree), "\n")
88
+ for truth in TRUTHS
89
+ test_truth = testTruth(tree, truth)
90
+ #print("Truth: ", truth, ": " , test_truth, "\n-----\n")
91
+ if !test_truth
92
+ append!(toReturn, [truth])
93
+ end
94
+ end
95
+ return toReturn
96
+ end
97
+
98
+ function randomIndex(cX::Array{Float32, 2}, k::Integer=10)::Array{Int32, 1}
99
+ indxs = sample([Int32(i) for i in 1:size(cX)[1]], k)
100
+ return indxs
101
+ end
102
+
103
+ function randomIndex(leng::Integer, k::Integer=10)::Array{Int32, 1}
104
+ indxs = sample([Int32(i) for i in 1:leng], k)
105
+ return indxs
106
+ end
107
+
108
+ function extendedX(cX::Array{Float32, 2}, truth::Truth, indx::Array{Int32, 1})::Array{Float32, 2}
109
+ workingcX = copy(cX)
110
+ X_slice = workingcX[indx, :]
111
+ X_transformed = transform(X_slice, truth)
112
+ return X_transformed
113
+ end
114
+ function extendedX(truth::Truth, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
115
+ return extendedX(OGX, truth, indx)
116
+ end
117
+ function extendedX(cX::Array{Float32, 2}, violatedTruths::Array{Truth}, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
118
+ if length(violatedTruths) == 0
119
+ return nothing
120
+ end
121
+ workingX = extendedX(cX, violatedTruths[1], indx)
122
+ for truth in violatedTruths[2:length(violatedTruths)]
123
+ workingX = vcat(workingX, extendedX(cX, truth, indx))
124
+ end
125
+ return workingX
126
+ end
127
+ function extendedX(violatedTruths::Array{Truth}, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
128
+ return extendedX(OGX, violatedTruths, indx)
129
+ end
130
+ function extendedX(tree::Node, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
131
+ violatedTruths = violatingTruths(tree)
132
+ return extendedX(violatedTruths, indx)
133
+ end
134
+ function extendedX(member::PopMember, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
135
+ return extendedX(member.tree, indx)
136
+ end
137
+
138
+
139
+ function extendedy(cX::Array{Float32, 2}, cy::Array{Float32}, truth::Truth, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
140
+ cy = copy(cy)
141
+ cX = copy(cX)
142
+ X_slice = cX[indx, :]
143
+ y_slice = cy[indx]
144
+ X_transformed = transform(X_slice, truth)
145
+ y_transformed = truthPrediction(X_transformed, y_slice, truth)
146
+ return y_transformed
147
+ end
148
+ function extendedy(truth::Truth, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
149
+ return extendedy(OGX, OGy, truth, indx)
150
+ end
151
+ function extendedy(cX::Array{Float32, 2}, cy::Array{Float32}, violatedTruths::Array{Truth}, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
152
+ if length(violatedTruths) == 0
153
+ return nothing
154
+ end
155
+ workingy = extendedy(cX, cy, violatedTruths[1], indx)
156
+ for truth in violatedTruths[2:length(violatedTruths)]
157
+ workingy = vcat(workingy, extendedy(cX, cy, truth, indx))
158
+ end
159
+ return workingy
160
+ end
161
+ function extendedy(violatedTruths::Array{Truth}, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
162
+ return extendedy(OGX,OGy, violatedTruths, indx)
163
+ end
164
+ function extendedy(tree::Node, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
165
+ violatedTruths = violatingTruths(tree)
166
+ return extendedy(violatedTruths, indx)
167
+ end
168
+ function extendedy(member::PopMember, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
169
+ return extendedy(member.tree, indx)
170
+ end
pysr/sr.py CHANGED
@@ -192,15 +192,7 @@ def pysr(X=None, y=None, weights=None,
192
  (as strings).
193
 
194
  """
195
- if threads is not None:
196
- raise ValueError("The threads kwarg is deprecated. Use procs.")
197
- if limitPowComplexity:
198
- raise ValueError("The limitPowComplexity kwarg is deprecated. Use constraints.")
199
- if maxdepth is None:
200
- maxdepth = maxsize
201
- if equation_file is None:
202
- date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
203
- equation_file = 'hall_of_fame_' + date_time + '.csv'
204
 
205
  if isinstance(X, pd.DataFrame):
206
  variable_names = list(X.columns)
@@ -211,119 +203,165 @@ def pysr(X=None, y=None, weights=None,
211
  if len(X.shape) == 1:
212
  X = X[:, None]
213
 
214
- # Check for potential errors before they happen
215
- assert len(unary_operators) + len(binary_operators) > 0
216
- assert len(X.shape) == 2
217
- assert len(y.shape) == 1
218
- assert X.shape[0] == y.shape[0]
219
- if weights is not None:
220
- assert len(weights.shape) == 1
221
- assert X.shape[0] == weights.shape[0]
222
- if use_custom_variable_names:
223
- assert len(variable_names) == X.shape[1]
224
 
225
 
226
  if len(X) > 10000 and not batching:
227
  warnings.warn("Note: you are running with more than 10,000 datapoints. You should consider turning on batching (https://pysr.readthedocs.io/en/latest/docs/options/#batching). You should also reconsider if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form with symbolic regression. More datapoints will lower the search speed.")
228
 
229
- if select_k_features is not None:
230
- selection = run_feature_selection(X, y, select_k_features)
231
- print(f"Using features {selection}")
232
- X = X[:, selection]
233
-
234
- if use_custom_variable_names:
235
- variable_names = [variable_names[selection[i]] for i in range(len(selection))]
236
 
 
 
 
 
 
237
  if populations is None:
238
  populations = procs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
- if isinstance(binary_operators, str): binary_operators = [binary_operators]
241
- if isinstance(unary_operators, str): unary_operators = [unary_operators]
242
 
243
- if X is None:
244
- if test == 'simple1':
245
- eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5"
246
- elif test == 'simple2':
247
- eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**3.5 + 1/(np.abs(X[:, 0])+1)"
248
- elif test == 'simple3':
249
- eval_str = "np.exp(X[:, 0]/2) + 12.0 + np.log(np.abs(X[:, 0])*10 + 1)"
250
- elif test == 'simple4':
251
- eval_str = "1.0 + 3*X[:, 0]**2 - 0.5*X[:, 0]**3 + 0.1*X[:, 0]**4"
252
- elif test == 'simple5':
253
- eval_str = "(np.exp(X[:, 3]) + 3)/(np.abs(X[:, 1]) + np.cos(X[:, 0]) + 1.1)"
254
-
255
- X = np.random.randn(100, 5)*3
256
- y = eval(eval_str)
257
- print("Running on", eval_str)
258
 
259
- # System-independent paths
260
- pkg_directory = Path(__file__).parents[1] / 'julia'
261
- pkg_filename = pkg_directory / "sr.jl"
262
- operator_filename = pkg_directory / "operators.jl"
263
 
264
- tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
265
- hyperparam_filename = tmpdir / f'hyperparams.jl'
266
- dataset_filename = tmpdir / f'dataset.jl'
267
- runfile_filename = tmpdir / f'runfile.jl'
268
- X_filename = tmpdir / "X.csv"
269
- y_filename = tmpdir / "y.csv"
270
- weights_filename = tmpdir / "weights.csv"
271
 
272
- def_hyperparams = ""
273
 
274
- # Add pre-defined functions to Julia
275
- for op_list in [binary_operators, unary_operators]:
276
- for i in range(len(op_list)):
277
- op = op_list[i]
278
- is_user_defined_operator = '(' in op
 
 
 
 
279
 
280
- if is_user_defined_operator:
281
- def_hyperparams += op + "\n"
282
- # Cut off from the first non-alphanumeric char:
283
- first_non_char = [
284
- j for j in range(len(op))
285
- if not (op[j].isalpha() or op[j].isdigit())][0]
286
- function_name = op[:first_non_char]
287
- op_list[i] = function_name
288
 
289
- #arbitrary complexity by default
290
- for op in unary_operators:
291
- if op not in constraints:
292
- constraints[op] = -1
293
- for op in binary_operators:
294
- if op not in constraints:
295
- constraints[op] = (-1, -1)
296
- if op in ['plus', 'sub']:
297
- if constraints[op][0] != constraints[op][1]:
298
- raise NotImplementedError("You need equal constraints on both sides for - and *, due to simplification strategies.")
299
- elif op == 'mult':
300
- # Make sure the complex expression is in the left side.
301
- if constraints[op][0] == -1:
302
- continue
303
- elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
304
- constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
305
 
306
- constraints_str = "const una_constraints = ["
307
- first = True
308
- for op in unary_operators:
309
- val = constraints[op]
310
- if not first:
311
- constraints_str += ", "
312
- constraints_str += f"{val:d}"
313
- first = False
314
 
315
- constraints_str += """]
316
- const bin_constraints = ["""
317
 
318
- first = True
319
- for op in binary_operators:
320
- tup = constraints[op]
321
- if not first:
322
- constraints_str += ", "
323
- constraints_str += f"({tup[0]:d}, {tup[1]:d})"
324
- first = False
325
- constraints_str += "]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
 
 
 
 
 
 
 
 
327
  def_hyperparams += f"""include("{_escape_filename(operator_filename)}")
328
  {constraints_str}
329
  const binops = {'[' + ', '.join(binary_operators) + ']'}
@@ -362,7 +400,6 @@ const warmupMaxsize = {warmupMaxsize:d}
362
  const limitPowComplexity = {"true" if limitPowComplexity else "false"}
363
  const useFrequency = {"true" if useFrequency else "false"}
364
  """
365
-
366
  op_runner = ""
367
  if len(binary_operators) > 0:
368
  op_runner += """
@@ -373,14 +410,13 @@ const useFrequency = {"true" if useFrequency else "false"}
373
  end"""
374
  for i in range(1, len(binary_operators)):
375
  op_runner += f"""
376
- elseif i === {i+1}
377
  @inbounds @simd for j=1:clen
378
  x[j] = {binary_operators[i]}(x[j], y[j])
379
  end"""
380
  op_runner += """
381
  end
382
  end"""
383
-
384
  if len(unary_operators) > 0:
385
  op_runner += """
386
  @inline function UNAOP!(x::Array{Float32, 1}, i::Int, clen::Int)
@@ -390,85 +426,160 @@ end"""
390
  end"""
391
  for i in range(1, len(unary_operators)):
392
  op_runner += f"""
393
- elseif i === {i+1}
394
  @inbounds @simd for j=1:clen
395
  x[j] = {unary_operators[i]}(x[j])
396
  end"""
397
  op_runner += """
398
  end
399
  end"""
400
-
401
  def_hyperparams += op_runner
 
 
 
 
402
 
403
- def_datasets = """using DelimitedFiles"""
404
-
405
- np.savetxt(X_filename, X, delimiter=',')
406
- np.savetxt(y_filename, y, delimiter=',')
407
- if weights is not None:
408
- np.savetxt(weights_filename, weights, delimiter=',')
409
 
410
- def_datasets += f"""
411
- const X = readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')
412
- const y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
- if weights is not None:
415
- def_datasets += f"""
416
- const weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')"""
417
 
418
- if use_custom_variable_names:
419
- def_hyperparams += f"""
420
- const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
- with open(hyperparam_filename, 'w') as f:
423
- print(def_hyperparams, file=f)
424
 
425
- with open(dataset_filename, 'w') as f:
426
- print(def_datasets, file=f)
 
 
 
 
427
 
428
- with open(runfile_filename, 'w') as f:
429
- print(f'@everywhere include("{_escape_filename(hyperparam_filename)}")', file=f)
430
- print(f'@everywhere include("{_escape_filename(dataset_filename)}")', file=f)
431
- print(f'@everywhere include("{_escape_filename(pkg_filename)}")', file=f)
432
- print(f'fullRun({niterations:d}, npop={npop:d}, ncyclesperiteration={ncyclesperiteration:d}, fractionReplaced={fractionReplaced:f}f0, verbosity=round(Int32, {verbosity:f}), topn={topn:d})', file=f)
433
- print(f'rmprocs(nprocs)', file=f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
 
 
 
435
 
436
- command = [
437
- f'julia', f'-O{julia_optimization:d}',
438
- f'-p', f'{procs}',
439
- str(runfile_filename),
440
- ]
441
- if timeout is not None:
442
- command = [f'timeout', f'{timeout}'] + command
443
 
444
- global global_n_features
445
- global global_equation_file
446
- global global_variable_names
447
- global global_extra_sympy_mappings
 
 
 
 
 
 
 
 
 
 
 
 
 
448
 
449
- global_n_features = X.shape[1]
450
- global_equation_file = equation_file
451
- global_variable_names = variable_names
452
- global_extra_sympy_mappings = extra_sympy_mappings
 
 
 
 
 
 
 
 
 
 
 
453
 
454
- print("Running on", ' '.join(command))
455
- process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=1)
456
- try:
457
- while True:
458
- line = process.stdout.readline()
459
- if not line: break
460
- print(line.decode('utf-8').replace('\n', ''))
461
 
462
- process.stdout.close()
463
- process.wait()
464
- except KeyboardInterrupt:
465
- print("Killing process... will return when done.")
466
- process.kill()
 
 
 
 
 
 
467
 
468
- if delete_tempfiles:
469
- shutil.rmtree(tmpdir)
470
 
471
- return get_hof()
 
 
 
 
472
 
473
 
474
  def run_feature_selection(X, y, select_k_features):
@@ -485,7 +596,7 @@ def run_feature_selection(X, y, select_k_features):
485
  max_features=select_k_features, prefit=True)
486
  return selector.get_support(indices=True)
487
 
488
- def get_hof(equation_file=None, n_features=None, variable_names=None, extra_sympy_mappings=None):
489
  """Get the equations from a hall of fame file. If no arguments
490
  entered, the ones used previously from a call to PySR will be used."""
491
 
 
192
  (as strings).
193
 
194
  """
195
+ _raise_depreciation_errors(limitPowComplexity, threads)
 
 
 
 
 
 
 
 
196
 
197
  if isinstance(X, pd.DataFrame):
198
  variable_names = list(X.columns)
 
203
  if len(X.shape) == 1:
204
  X = X[:, None]
205
 
206
+ _check_assertions(X, binary_operators, unary_operators,
207
+ use_custom_variable_names, variable_names, weights, y)
 
 
 
 
 
 
 
 
208
 
209
 
210
  if len(X) > 10000 and not batching:
211
  warnings.warn("Note: you are running with more than 10,000 datapoints. You should consider turning on batching (https://pysr.readthedocs.io/en/latest/docs/options/#batching). You should also reconsider if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form with symbolic regression. More datapoints will lower the search speed.")
212
 
213
+ X, variable_names = _handle_feature_selection(
214
+ X, select_k_features,
215
+ use_custom_variable_names, variable_names, y
216
+ )
 
 
 
217
 
218
+ if maxdepth is None:
219
+ maxdepth = maxsize
220
+ if equation_file is None:
221
+ date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
222
+ equation_file = 'hall_of_fame_' + date_time + '.csv'
223
  if populations is None:
224
  populations = procs
225
+ if isinstance(binary_operators, str):
226
+ binary_operators = [binary_operators]
227
+ if isinstance(unary_operators, str):
228
+ unary_operators = [unary_operators]
229
+ if X is None:
230
+ X, y = _using_test_input(X, test, y)
231
+
232
+ kwargs = dict(X=X, y=y, weights=weights,
233
+ alpha=alpha, annealing=annealing, batchSize=batchSize,
234
+ batching=batching, binary_operators=binary_operators,
235
+ equation_file=equation_file, fast_cycle=fast_cycle,
236
+ fractionReplaced=fractionReplaced,
237
+ ncyclesperiteration=ncyclesperiteration,
238
+ niterations=niterations, npop=npop,
239
+ topn=topn, verbosity=verbosity,
240
+ julia_optimization=julia_optimization, timeout=timeout,
241
+ fractionReplacedHof=fractionReplacedHof,
242
+ hofMigration=hofMigration,
243
+ limitPowComplexity=limitPowComplexity, maxdepth=maxdepth,
244
+ maxsize=maxsize, migration=migration, nrestarts=nrestarts,
245
+ parsimony=parsimony, perturbationFactor=perturbationFactor,
246
+ populations=populations, procs=procs,
247
+ shouldOptimizeConstants=shouldOptimizeConstants,
248
+ unary_operators=unary_operators, useFrequency=useFrequency,
249
+ use_custom_variable_names=use_custom_variable_names,
250
+ variable_names=variable_names, warmupMaxsize=warmupMaxsize,
251
+ weightAddNode=weightAddNode,
252
+ weightDeleteNode=weightDeleteNode,
253
+ weightDoNothing=weightDoNothing,
254
+ weightInsertNode=weightInsertNode,
255
+ weightMutateConstant=weightMutateConstant,
256
+ weightMutateOperator=weightMutateOperator,
257
+ weightRandomize=weightRandomize,
258
+ weightSimplify=weightSimplify,
259
+ constraints=constraints,
260
+ extra_sympy_mappings=extra_sympy_mappings)
261
+
262
+ kwargs = {**_set_paths(tempdir), **kwargs}
263
+
264
+ kwargs['def_hyperparams'] = _metaprogram_fast_operator(**kwargs)
265
+
266
+ _handle_constraints(**kwargs)
267
+
268
+ kwargs['constraints_str'] = _make_constraints_str(**kwargs)
269
+ kwargs['def_hyperparams'] = _make_hyperparams_julia_str(**kwargs)
270
+ kwargs['def_auxiliary'] = _make_auxiliary_julia_str(**kwargs)
271
+ kwargs['def_datasets'] = _make_datasets_julia_str(**kwargs)
272
+
273
+ _create_julia_files(**kwargs)
274
+ _final_pysr_process(**kwargs)
275
+ _set_globals(**kwargs)
276
 
277
+ if delete_tempfiles:
278
+ shutil.rmtree(kwargs['tmpdir'])
279
 
280
+ return get_hof(**kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
 
 
 
 
 
282
 
283
+ def _make_auxiliary_julia_str(julia_auxiliary_filenames, **kwargs):
284
+ def_auxiliary = '\n'.join([
285
+ f"""include("{_escape_filename(aux_fname)}")""" for aux_fname in julia_auxiliary_filenames
286
+ ])
287
+ return def_auxiliary
 
 
288
 
 
289
 
290
+ def _set_globals(X, equation_file, extra_sympy_mappings, variable_names, **kwargs):
291
+ global global_n_features
292
+ global global_equation_file
293
+ global global_variable_names
294
+ global global_extra_sympy_mappings
295
+ global_n_features = X.shape[1]
296
+ global_equation_file = equation_file
297
+ global_variable_names = variable_names
298
+ global_extra_sympy_mappings = extra_sympy_mappings
299
 
 
 
 
 
 
 
 
 
300
 
301
+ def _final_pysr_process(julia_optimization, procs, runfile_filename, timeout, **kwargs):
302
+ command = [
303
+ f'julia', f'-O{julia_optimization:d}',
304
+ f'-p', f'{procs}',
305
+ str(runfile_filename),
306
+ ]
307
+ if timeout is not None:
308
+ command = [f'timeout', f'{timeout}'] + command
309
+ print("Running on", ' '.join(command))
310
+ process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=1)
311
+ try:
312
+ while True:
313
+ line = process.stdout.readline()
314
+ if not line: break
315
+ print(line.decode('utf-8').replace('\n', ''))
 
316
 
317
+ process.stdout.close()
318
+ process.wait()
319
+ except KeyboardInterrupt:
320
+ print("Killing process... will return when done.")
321
+ process.kill()
 
 
 
322
 
 
 
323
 
324
+ def _create_julia_files(auxiliary_filename, dataset_filename, def_auxiliary, def_datasets, def_hyperparams, fractionReplaced, hyperparam_filename,
325
+ ncyclesperiteration, niterations, npop, pkg_filename, runfile_filename, topn, verbosity, **kwargs):
326
+ with open(hyperparam_filename, 'w') as f:
327
+ print(def_hyperparams, file=f)
328
+ with open(dataset_filename, 'w') as f:
329
+ print(def_datasets, file=f)
330
+ with open(auxiliary_filename, 'w') as f:
331
+ print(def_auxiliary, file=f)
332
+ with open(runfile_filename, 'w') as f:
333
+ print(f'@everywhere include("{_escape_filename(hyperparam_filename)}")', file=f)
334
+ print(f'@everywhere include("{_escape_filename(dataset_filename)}")', file=f)
335
+ print(f'@everywhere include("{_escape_filename(auxiliary_filename)}")', file=f)
336
+ print(f'@everywhere include("{_escape_filename(pkg_filename)}")', file=f)
337
+ print(
338
+ f'fullRun({niterations:d}, npop={npop:d}, ncyclesperiteration={ncyclesperiteration:d}, fractionReplaced={fractionReplaced:f}f0, verbosity=round(Int32, {verbosity:f}), topn={topn:d})',
339
+ file=f)
340
+ print(f'rmprocs(nprocs)', file=f)
341
+
342
+
343
+ def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename, **kwargs):
344
+ def_datasets = """using DelimitedFiles"""
345
+ np.savetxt(X_filename, X, delimiter=',')
346
+ np.savetxt(y_filename, y, delimiter=',')
347
+ if weights is not None:
348
+ np.savetxt(weights_filename, weights, delimiter=',')
349
+ def_datasets += f"""
350
+ const X = readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')
351
+ const y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')"""
352
+ if weights is not None:
353
+ def_datasets += f"""
354
+ const weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')"""
355
+ return def_datasets
356
 
357
+
358
+ def _make_hyperparams_julia_str(X, alpha, annealing, batchSize, batching, binary_operators, constraints_str,
359
+ def_hyperparams, equation_file, fast_cycle, fractionReplacedHof, hofMigration,
360
+ limitPowComplexity, maxdepth, maxsize, migration, nrestarts, operator_filename,
361
+ parsimony, perturbationFactor, populations, procs, shouldOptimizeConstants,
362
+ unary_operators, useFrequency, use_custom_variable_names, variable_names, warmupMaxsize, weightAddNode,
363
+ weightDeleteNode, weightDoNothing, weightInsertNode, weightMutateConstant,
364
+ weightMutateOperator, weightRandomize, weightSimplify, weights, **kwargs):
365
  def_hyperparams += f"""include("{_escape_filename(operator_filename)}")
366
  {constraints_str}
367
  const binops = {'[' + ', '.join(binary_operators) + ']'}
 
400
  const limitPowComplexity = {"true" if limitPowComplexity else "false"}
401
  const useFrequency = {"true" if useFrequency else "false"}
402
  """
 
403
  op_runner = ""
404
  if len(binary_operators) > 0:
405
  op_runner += """
 
410
  end"""
411
  for i in range(1, len(binary_operators)):
412
  op_runner += f"""
413
+ elseif i === {i + 1}
414
  @inbounds @simd for j=1:clen
415
  x[j] = {binary_operators[i]}(x[j], y[j])
416
  end"""
417
  op_runner += """
418
  end
419
  end"""
 
420
  if len(unary_operators) > 0:
421
  op_runner += """
422
  @inline function UNAOP!(x::Array{Float32, 1}, i::Int, clen::Int)
 
426
  end"""
427
  for i in range(1, len(unary_operators)):
428
  op_runner += f"""
429
+ elseif i === {i + 1}
430
  @inbounds @simd for j=1:clen
431
  x[j] = {unary_operators[i]}(x[j])
432
  end"""
433
  op_runner += """
434
  end
435
  end"""
 
436
  def_hyperparams += op_runner
437
+ if use_custom_variable_names:
438
+ def_hyperparams += f"""
439
+ const varMap = {'["' + '", "'.join(variable_names) + '"]'}"""
440
+ return def_hyperparams
441
 
 
 
 
 
 
 
442
 
443
+ def _make_constraints_str(binary_operators, constraints, unary_operators, **kwargs):
444
+ constraints_str = "const una_constraints = ["
445
+ first = True
446
+ for op in unary_operators:
447
+ val = constraints[op]
448
+ if not first:
449
+ constraints_str += ", "
450
+ constraints_str += f"{val:d}"
451
+ first = False
452
+ constraints_str += """]
453
+ const bin_constraints = ["""
454
+ first = True
455
+ for op in binary_operators:
456
+ tup = constraints[op]
457
+ if not first:
458
+ constraints_str += ", "
459
+ constraints_str += f"({tup[0]:d}, {tup[1]:d})"
460
+ first = False
461
+ constraints_str += "]"
462
+ return constraints_str
463
 
 
 
 
464
 
465
+ def _handle_constraints(binary_operators, constraints, unary_operators, **kwargs):
466
+ for op in unary_operators:
467
+ if op not in constraints:
468
+ constraints[op] = -1
469
+ for op in binary_operators:
470
+ if op not in constraints:
471
+ constraints[op] = (-1, -1)
472
+ if op in ['plus', 'sub']:
473
+ if constraints[op][0] != constraints[op][1]:
474
+ raise NotImplementedError(
475
+ "You need equal constraints on both sides for - and *, due to simplification strategies.")
476
+ elif op == 'mult':
477
+ # Make sure the complex expression is in the left side.
478
+ if constraints[op][0] == -1:
479
+ continue
480
+ elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
481
+ constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
482
 
 
 
483
 
484
+ def _metaprogram_fast_operator(binary_operators, unary_operators, **kwargs):
485
+ def_hyperparams = ""
486
+ for op_list in [binary_operators, unary_operators]:
487
+ for i in range(len(op_list)):
488
+ op = op_list[i]
489
+ is_user_defined_operator = '(' in op
490
 
491
+ if is_user_defined_operator:
492
+ def_hyperparams += op + "\n"
493
+ # Cut off from the first non-alphanumeric char:
494
+ first_non_char = [
495
+ j for j in range(len(op))
496
+ if not (op[j].isalpha() or op[j].isdigit())][0]
497
+ function_name = op[:first_non_char]
498
+ op_list[i] = function_name
499
+ return def_hyperparams
500
+
501
+
502
+ def _using_test_input(X, test, y):
503
+ if test == 'simple1':
504
+ eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5"
505
+ elif test == 'simple2':
506
+ eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**3.5 + 1/(np.abs(X[:, 0])+1)"
507
+ elif test == 'simple3':
508
+ eval_str = "np.exp(X[:, 0]/2) + 12.0 + np.log(np.abs(X[:, 0])*10 + 1)"
509
+ elif test == 'simple4':
510
+ eval_str = "1.0 + 3*X[:, 0]**2 - 0.5*X[:, 0]**3 + 0.1*X[:, 0]**4"
511
+ elif test == 'simple5':
512
+ eval_str = "(np.exp(X[:, 3]) + 3)/(np.abs(X[:, 1]) + np.cos(X[:, 0]) + 1.1)"
513
+ X = np.random.randn(100, 5) * 3
514
+ y = eval(eval_str)
515
+ print("Running on", eval_str)
516
+ return X, y
517
+
518
+
519
+ def _handle_feature_selection(X, select_k_features, use_custom_variable_names, variable_names, y):
520
+ if select_k_features is not None:
521
+ selection = run_feature_selection(X, y, select_k_features)
522
+ print(f"Using features {selection}")
523
+ X = X[:, selection]
524
 
525
+ if use_custom_variable_names:
526
+ variable_names = [variable_names[selection[i]] for i in range(len(selection))]
527
+ return X, variable_names
528
 
 
 
 
 
 
 
 
529
 
530
+ def _set_paths(tempdir):
531
+ # System-independent paths
532
+ pkg_directory = Path(__file__).parents[1] / 'julia'
533
+ pkg_filename = pkg_directory / "sr.jl"
534
+ operator_filename = pkg_directory / "Operators.jl"
535
+ julia_auxiliaries = [
536
+ "Equation.jl", "ProgramConstants.jl",
537
+ "LossFunctions.jl", "Utils.jl", "EvaluateEquation.jl",
538
+ "MutationFunctions.jl", "SimplifyEquation.jl", "PopMember.jl",
539
+ "HallOfFame.jl", "CheckConstraints.jl", "Mutate.jl",
540
+ "Population.jl", "RegularizedEvolution.jl", "SingleIteration.jl",
541
+ "ConstantOptimization.jl"
542
+ ]
543
+ julia_auxiliary_filenames = [
544
+ pkg_directory / fname
545
+ for fname in julia_auxiliaries
546
+ ]
547
 
548
+ tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
549
+ hyperparam_filename = tmpdir / f'hyperparams.jl'
550
+ dataset_filename = tmpdir / f'dataset.jl'
551
+ auxiliary_filename = tmpdir / f'auxiliary.jl'
552
+ runfile_filename = tmpdir / f'runfile.jl'
553
+ X_filename = tmpdir / "X.csv"
554
+ y_filename = tmpdir / "y.csv"
555
+ weights_filename = tmpdir / "weights.csv"
556
+ return dict(auxiliary_filename=auxiliary_filename, X_filename=X_filename,
557
+ dataset_filename=dataset_filename,
558
+ hyperparam_filename=hyperparam_filename,
559
+ julia_auxiliary_filenames=julia_auxiliary_filenames,
560
+ operator_filename=operator_filename, pkg_filename=pkg_filename,
561
+ runfile_filename=runfile_filename, tmpdir=tmpdir,
562
+ weights_filename=weights_filename, y_filename=y_filename)
563
 
 
 
 
 
 
 
 
564
 
565
+ def _check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y):
566
+ # Check for potential errors before they happen
567
+ assert len(unary_operators) + len(binary_operators) > 0
568
+ assert len(X.shape) == 2
569
+ assert len(y.shape) == 1
570
+ assert X.shape[0] == y.shape[0]
571
+ if weights is not None:
572
+ assert len(weights.shape) == 1
573
+ assert X.shape[0] == weights.shape[0]
574
+ if use_custom_variable_names:
575
+ assert len(variable_names) == X.shape[1]
576
 
 
 
577
 
578
+ def _raise_depreciation_errors(limitPowComplexity, threads):
579
+ if threads is not None:
580
+ raise ValueError("The threads kwarg is deprecated. Use procs.")
581
+ if limitPowComplexity:
582
+ raise ValueError("The limitPowComplexity kwarg is deprecated. Use constraints.")
583
 
584
 
585
  def run_feature_selection(X, y, select_k_features):
 
596
  max_features=select_k_features, prefit=True)
597
  return selector.get_support(indices=True)
598
 
599
+ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_sympy_mappings=None, **kwargs):
600
  """Get the equations from a hall of fame file. If no arguments
601
  entered, the ones used previously from a call to PySR will be used."""
602