MilesCranmer commited on
Commit
688106d
0 Parent(s):

First commit

Browse files
Files changed (1) hide show
  1. eureqa.jl +440 -0
eureqa.jl ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ using ProgressBars
2
+
3
+ # Define allowed operators
4
+ plus(x::Float64, y::Float64) = x+y
5
+ mult(x::Float64, y::Float64) = x*y;
6
+
7
+ # (Apparently using const for globals helps speed)
8
+ const binops = [plus, mult]
9
+ const unaops = [sin, cos, exp];
10
+
11
+ const nvar = 5;
12
+ const X = rand(100, nvar);
13
+
14
+ # Here is the function we want to learn (x2^2 + cos(x3) + 5)
15
+ const y = ((cx,)->cx^2).(X[:, 2]) + cos.(X[:, 3]) .+ 5.0;
16
+
17
+ # How much to punish complexity
18
+ const parsimony = 0.01
19
+ # How much to scale temperature by (T between 0 and 1)
20
+ const alpha = 10.0
21
+
22
+
23
+
24
+
25
+ id = (x,) -> x
26
+ const nuna = size(unaops)[1]
27
+ const nbin = size(binops)[1]
28
+ const nops = nuna + nbin
29
+
30
+ # Define a serialization format for the symbolic equations:
31
+ mutable struct Node
32
+ #Holds operators, variables, constants in a tree
33
+ degree::Int #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
34
+ val::Union{Float64, Int} #Either const value, or enumerates variable
35
+ constant::Bool #false if variable
36
+ op::Function #enumerates operator (for degree=1,2)
37
+ l::Union{Node, Nothing}
38
+ r::Union{Node, Nothing}
39
+
40
+ Node(val::Float64) = new(0, val, true, id, nothing, nothing)
41
+ Node(val::Int) = new(0, val, false, id, nothing, nothing)
42
+ Node(op, l::Node) = new(1, 0.0, false, op, l, nothing)
43
+ Node(op, l::Union{Float64, Int}) = new(1, 0.0, false, op, Node(l), nothing)
44
+ Node(op, l::Node, r::Node) = new(2, 0.0, false, op, l, r)
45
+
46
+ #Allow to pass the leaf value without additional node call:
47
+ Node(op, l::Union{Float64, Int}, r::Node) = new(2, 0.0, false, op, Node(l), r)
48
+ Node(op, l::Node, r::Union{Float64, Int}) = new(2, 0.0, false, op, l, Node(r))
49
+ Node(op, l::Union{Float64, Int}, r::Union{Float64, Int}) = new(2, 0.0, false, op, Node(l), Node(r))
50
+ end
51
+
52
+ # Evaluate a symbolic equation:
53
+ function evalTree(tree::Node, x::Array{Float64, 1}=Float64[])::Float64
54
+ if tree.degree == 0
55
+ if tree.constant
56
+ return tree.val
57
+ else
58
+ return x[tree.val]
59
+ end
60
+ elseif tree.degree == 1
61
+ return tree.op(evalTree(tree.l, x))
62
+ else
63
+ return tree.op(evalTree(tree.l, x), evalTree(tree.r, x))
64
+ end
65
+ end
66
+
67
+ # Count the operators, constants, variables in an equation
68
+ function countNodes(tree::Node)::Int
69
+ if tree.degree == 0
70
+ return 1
71
+ elseif tree.degree == 1
72
+ return 1 + countNodes(tree.l)
73
+ else
74
+ return 1 + countNodes(tree.l) + countNodes(tree.r)
75
+ end
76
+ end
77
+
78
+ # Convert an equation to a string
79
+ function stringTree(tree::Node)::String
80
+ if tree.degree == 0
81
+ if tree.constant
82
+ return string(tree.val)
83
+ else
84
+ return "x$(tree.val)"
85
+ end
86
+ elseif tree.degree == 1
87
+ return "$(tree.op)($(stringTree(tree.l)))"
88
+ else
89
+ return "$(tree.op)($(stringTree(tree.l)), $(stringTree(tree.r)))"
90
+ end
91
+ end
92
+
93
+ # Print an equation
94
+ function printTree(tree::Node)
95
+ println(stringTree(tree))
96
+ end
97
+
98
+ # Return a random node from the tree
99
+ function randomNode(tree::Node)::Node
100
+ if tree.degree == 0
101
+ return tree
102
+ end
103
+ a = countNodes(tree)
104
+ b = 0
105
+ c = 0
106
+ if tree.degree >= 1
107
+ b = countNodes(tree.l)
108
+ end
109
+ if tree.degree == 2
110
+ c = countNodes(tree.r)
111
+ end
112
+
113
+ i = rand(1:1+b+c)
114
+ if i <= b
115
+ return randomNode(tree.l)
116
+ elseif i == b + 1
117
+ return tree
118
+ end
119
+
120
+ return randomNode(tree.r)
121
+ end
122
+
123
+ # Count the number of unary operators in the equation
124
+ function countUnaryOperators(tree::Node)::Int
125
+ if tree.degree == 0
126
+ return 0
127
+ elseif tree.degree == 1
128
+ return 1 + countUnaryOperators(tree.l)
129
+ else
130
+ return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
131
+ end
132
+ end
133
+
134
+ # Count the number of binary operators in the equation
135
+ function countBinaryOperators(tree::Node)::Int
136
+ if tree.degree == 0
137
+ return 0
138
+ elseif tree.degree == 1
139
+ return 0 + countBinaryOperators(tree.l)
140
+ else
141
+ return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
142
+ end
143
+ end
144
+
145
+ # Count the number of operators in the equation
146
+ function countOperators(tree::Node)::Int
147
+ return countUnaryOperators(tree) + countBinaryOperators(tree)
148
+ end
149
+
150
+ # Randomly convert an operator into another one (binary->binary;
151
+ # unary->unary)
152
+ function mutateOperator(tree::Node)::Node
153
+ if countOperators(tree) == 0
154
+ return tree
155
+ end
156
+ node = randomNode(tree)
157
+ while node.degree == 0
158
+ node = randomNode(tree)
159
+ end
160
+ if node.degree == 1
161
+ node.op = unaops[rand(1:length(unaops))]
162
+ else
163
+ node.op = binops[rand(1:length(binops))]
164
+ end
165
+ return tree
166
+ end
167
+
168
+ # Count the number of constants in an equation
169
+ function countConstants(tree::Node)::Int
170
+ if tree.degree == 0
171
+ return convert(Int, tree.constant)
172
+ elseif tree.degree == 1
173
+ return 0 + countConstants(tree.l)
174
+ else
175
+ return 0 + countConstants(tree.l) + countConstants(tree.r)
176
+ end
177
+ end
178
+
179
+ # Randomly perturb a constant
180
+ function mutateConstant(
181
+ tree::Node, T::Float64,
182
+ probNegate::Float64=0.01)::Node
183
+ # T is between 0 and 1.
184
+
185
+ if countConstants(tree) == 0
186
+ return tree
187
+ end
188
+ node = randomNode(tree)
189
+ while node.degree != 0 || node.constant == false
190
+ node = randomNode(tree)
191
+ end
192
+
193
+ maxChange = T + 1.0
194
+ factor = maxChange^rand()
195
+ makeConstBigger = rand() > 0.5
196
+
197
+ if makeConstBigger
198
+ node.val *= factor
199
+ else
200
+ node.val /= factor
201
+ end
202
+
203
+ if rand() > probNegate
204
+ node.val *= -1
205
+ end
206
+
207
+ return tree
208
+ end
209
+
210
+ # Evaluate an equation over an array of datapoints
211
+ function evalTreeArray(
212
+ tree::Node,
213
+ x::Array{Float64, 2})::Array{Float64, 1}
214
+ return mapslices(
215
+ (cx,) -> evalTree(tree, cx),
216
+ x,
217
+ dims=[2]
218
+ )[:, 1]
219
+ end
220
+
221
+ # Sum of square error between two arrays
222
+ function SSE(x::Array{Float64}, y::Array{Float64})::Float64
223
+ return sum(((cx,)->cx^2).(x - y))
224
+ end
225
+
226
+ # Mean of square error between two arrays
227
+ function MSE(x::Array{Float64}, y::Array{Float64})::Float64
228
+ return SSE(x, y)/size(x)[1]
229
+ end
230
+
231
+ # Score an equation
232
+ function scoreFunc(
233
+ tree::Node,
234
+ X::Array{Float64, 2},
235
+ y::Array{Float64, 1},
236
+ parsimony::Float64=0.1)::Float64
237
+ return MSE(evalTreeArray(tree, X), y) + countNodes(tree)*parsimony
238
+ end
239
+
240
+ # Add a random unary/binary operation to the end of a tree
241
+ function appendRandomOp(tree::Node)::Node
242
+ node = randomNode(tree)
243
+ while node.degree != 0
244
+ node = randomNode(tree)
245
+ end
246
+
247
+ choice = rand()
248
+ makeNewBinOp = choice < nbin/nops
249
+ if rand() > 0.5
250
+ left = randn()
251
+ else
252
+ left = rand(1:nvar)
253
+ end
254
+ if rand() > 0.5
255
+ right = randn()
256
+ else
257
+ right = rand(1:nvar)
258
+ end
259
+
260
+ if makeNewBinOp
261
+ newnode = Node(
262
+ binops[rand(1:length(binops))],
263
+ left,
264
+ right
265
+ )
266
+ else
267
+ newnode = Node(
268
+ unaops[rand(1:length(unaops))],
269
+ left
270
+ )
271
+ end
272
+ node.l = newnode.l
273
+ node.r = newnode.r
274
+ node.op = newnode.op
275
+ node.degree = newnode.degree
276
+ node.val = newnode.val
277
+ node.constant = newnode.constant
278
+ return tree
279
+ end
280
+
281
+ # Select a random node, and replace it an the subtree
282
+ # with a variable or constant
283
+ function deleteRandomOp(tree::Node)::Node
284
+ node = randomNode(tree)
285
+ # Can "delete" variable or constant too
286
+ if rand() > 0.5
287
+ val = randn()
288
+ else
289
+ val = rand(1:nvar)
290
+ end
291
+ newnode = Node(val)
292
+ node.l = newnode.l
293
+ node.r = newnode.r
294
+ node.op = newnode.op
295
+ node.degree = newnode.degree
296
+ node.val = newnode.val
297
+ node.constant = newnode.constant
298
+ return tree
299
+ end
300
+
301
+ # Go through one simulated annealing mutation cycle
302
+ # exp(-delta/T) defines probability of accepting a change
303
+ function iterate(
304
+ tree::Node, T::Float64,
305
+ X::Array{Float64, 2}, y::Array{Float64, 1},
306
+ alpha::Float64=1.0,
307
+ mult::Float64=0.1
308
+ )::Node
309
+ prev = deepcopy(tree)
310
+
311
+ mutationChoice = rand()
312
+ weights = [8, 1, 1, 1]
313
+ weights /= sum(weights)
314
+ cweights = cumsum(weights)
315
+
316
+ if mutationChoice < cweights[1]
317
+ tree = mutateConstant(tree, T)
318
+ elseif mutationChoice < cweights[2]
319
+ tree = mutateOperator(tree)
320
+ elseif mutationChoice < cweights[3]
321
+ tree = appendRandomOp(tree)
322
+ elseif mutationChoice < cweights[4]
323
+ tree = deleteRandomOp(tree)
324
+ end
325
+
326
+ try
327
+ beforeLoss = scoreFunc(prev, X, y, mult)
328
+ afterLoss = scoreFunc(tree, X, y, mult)
329
+ delta = afterLoss - beforeLoss
330
+ probChange = exp(-delta/(T*alpha))
331
+
332
+ if probChange > rand()
333
+ return tree
334
+ end
335
+
336
+ return prev
337
+ catch error
338
+ # Sometimes too many chained exp operators
339
+ if isa(error, DomainError)
340
+ return prev
341
+ else
342
+ throw(error)
343
+ end
344
+ end
345
+ end
346
+
347
+ # Create a random equation by appending random operators
348
+ function genRandomTree(length::Int)::Node
349
+ tree = Node(1.0)
350
+ for i=1:length
351
+ tree = appendRandomOp(tree)
352
+ end
353
+ return tree
354
+ end
355
+
356
+
357
+ # Define a member of population by equation, score, and age
358
+ mutable struct PopMember
359
+ tree::Node
360
+ score::Float64
361
+ birth::Float64
362
+
363
+ PopMember(t) = new(t, scoreFunc(t, X, y, parsimony), time()-1.6e9)
364
+ end
365
+
366
+ # A list of members of the population, with easy constructors,
367
+ # which allow for random generation of new populations
368
+ mutable struct Population
369
+ members::Array{PopMember, 1}
370
+ n::Int
371
+
372
+ Population(pop::Array{PopMember, 1}) = new(pop, size(pop)[1])
373
+ Population(npop::Int64) = new([PopMember(genRandomTree(3)) for i=1:npop], npop)
374
+ Population(npop::Int64, nlength::Int64) = new([PopMember(genRandomTree(nlength)) for i=1:npop], npop)
375
+
376
+ end
377
+
378
+ # Sample 10 random members of the population, and make a new one
379
+ function samplePop(pop::Population)::Population
380
+ idx = rand(1:pop.n, 10)
381
+ return Population(pop.members[idx])#Population(deepcopy(pop.members[idx]))
382
+ end
383
+
384
+ # Sample the population, and get the best member from that sample
385
+ function bestOfSample(pop::Population)::PopMember
386
+ sample = samplePop(pop)
387
+ best_idx = argmin([sample.members[member].score for member=1:sample.n])
388
+ return sample.members[best_idx]
389
+ end
390
+
391
+ # Mutate the best sampled member of the population
392
+ function iterateSample(pop::Population, T::Float64)::PopMember
393
+ allstar = bestOfSample(pop)
394
+ new = iterate(allstar.tree, T, X, y, alpha, parsimony)
395
+ allstar.tree = new
396
+ allstar.score = scoreFunc(new, X, y, parsimony)
397
+ allstar.birth = time() - 1.6e9
398
+ return allstar
399
+ end
400
+
401
+ # Pass through the population several times, replacing the oldest
402
+ # with the fittest of a small subsample
403
+ function regEvolCycle(pop::Population, T::Float64)::Population
404
+ for i=1:Int(pop.n/10)
405
+ baby = iterateSample(pop, T)
406
+ oldest = argmin([pop.members[member].birth for member=1:pop.n])
407
+ pop.members[oldest] = baby
408
+ end
409
+ return pop
410
+ end
411
+
412
+ # Cycle through regularized evolution many times,
413
+ # printing the fittest equation every 10% through
414
+ function run(ncycles::Int,
415
+ npop::Int=100,
416
+ annealing::Bool=false)::Population
417
+
418
+ allT = LinRange(1.0, 0.0, ncycles)
419
+ pop = Population(npop, 3)
420
+ bestScore = Inf
421
+ for iT in tqdm(1:size(allT)[1])
422
+ if annealing
423
+ pop = regEvolCycle(pop, allT[iT])
424
+ else
425
+ pop = regEvolCycle(pop, 0.0)
426
+ end
427
+ bestCurScoreIdx = argmin([pop.members[member].score for member=1:pop.n])
428
+ bestCurScore = pop.members[bestCurScoreIdx].score
429
+ if bestCurScore < bestScore
430
+ bestScore = bestCurScore
431
+ println(bestScore, " is the score for ", stringTree(pop.members[bestCurScoreIdx].tree))
432
+ end
433
+ end
434
+ return pop
435
+ end
436
+
437
+ println("Lets try to learn (x2^2 + cos(x3) + 5) using regularized evolution from scratch")
438
+ pop = run(10000, 1000, false);
439
+
440
+