MilesCranmer commited on
Commit
41cab3e
1 Parent(s): 9d82fc3

Vectorize evaluation; get 5x speedup

Browse files
Files changed (2) hide show
  1. benchmarktimes.txt +1 -0
  2. eureqa.jl +32 -26
benchmarktimes.txt CHANGED
@@ -1 +1,2 @@
1
  The median time is TrialEstimate(3.136 s)
 
 
1
  The median time is TrialEstimate(3.136 s)
2
+ The median time is TrialEstimate(611.157 ms)
eureqa.jl CHANGED
@@ -44,17 +44,17 @@ mutable struct Node
44
  #Holds operators, variables, constants in a tree
45
  degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
46
  val::Union{Float32, Integer} #Either const value, or enumerates variable
47
- constant::Bool #false if variable
48
  op::Function #enumerates operator (for degree=1,2)
49
  l::Union{Node, Nothing}
50
  r::Union{Node, Nothing}
51
-
52
  Node(val::Float32) = new(0, val, true, id, nothing, nothing)
53
  Node(val::Integer) = new(0, val, false, id, nothing, nothing)
54
  Node(op, l::Node) = new(1, 0.0f0, false, op, l, nothing)
55
  Node(op, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
56
  Node(op, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
57
-
58
  #Allow to pass the leaf value without additional node call:
59
  Node(op, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
60
  Node(op, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
@@ -121,15 +121,15 @@ function randomNode(tree::Node)::Node
121
  if tree.degree == 2
122
  c = countNodes(tree.r)
123
  end
124
-
125
  i = rand(1:1+b+c)
126
  if i <= b
127
  return randomNode(tree.l)
128
  elseif i == b + 1
129
  return tree
130
  end
131
-
132
- return randomNode(tree.r)
133
  end
134
 
135
  # Count the number of unary operators in the equation
@@ -193,7 +193,7 @@ function mutateConstant(
193
  tree::Node, T::Float32,
194
  probNegate::Float32=0.01f0)::Node
195
  # T is between 0 and 1.
196
-
197
  if countConstants(tree) == 0
198
  return tree
199
  end
@@ -201,22 +201,22 @@ function mutateConstant(
201
  while node.degree != 0 || node.constant == false
202
  node = randomNode(tree)
203
  end
204
-
205
  bottom = 0.1f0
206
  maxChange = T + 1.0f0 + bottom
207
  factor = maxChange^Float32(rand())
208
  makeConstBigger = rand() > 0.5
209
-
210
- if makeConstBigger
211
  node.val *= factor
212
  else
213
  node.val /= factor
214
  end
215
-
216
  if rand() > probNegate
217
  node.val *= -1
218
  end
219
-
220
  return tree
221
  end
222
 
@@ -224,11 +224,18 @@ end
224
  function evalTreeArray(
225
  tree::Node,
226
  x::Array{Float32, 2})::Array{Float32, 1}
227
- return mapslices(
228
- (cx,) -> evalTree(tree, cx),
229
- x,
230
- dims=[2]
231
- )[:, 1]
 
 
 
 
 
 
 
232
  end
233
 
234
  # Sum of square error between two arrays
@@ -260,7 +267,7 @@ function appendRandomOp(tree::Node)::Node
260
  while node.degree != 0
261
  node = randomNode(tree)
262
  end
263
-
264
  choice = rand()
265
  makeNewBinOp = choice < nbin/nops
266
  if rand() > 0.5
@@ -273,7 +280,7 @@ function appendRandomOp(tree::Node)::Node
273
  else
274
  right = rand(1:nvar)
275
  end
276
-
277
  if makeNewBinOp
278
  newnode = Node(
279
  binops[rand(1:length(binops))],
@@ -324,14 +331,14 @@ function iterate(
324
  mult::Float32=0.1f0
325
  )::Node
326
  prev = deepcopy(tree)
327
-
328
  mutationChoice = rand()
329
  weight_for_constant = min(8, countConstants(tree))
330
  weights = [weight_for_constant, 1, 1, 1, 2]
331
  weights /= sum(weights)
332
  cweights = cumsum(weights)
333
  n = countNodes(tree)
334
-
335
  if mutationChoice < cweights[1]
336
  tree = mutateConstant(tree, T)
337
  elseif mutationChoice < cweights[2]
@@ -343,7 +350,7 @@ function iterate(
343
  else
344
  tree = tree
345
  end
346
-
347
  try
348
  beforeLoss = scoreFunc(prev, X, y, parsimony=mult)
349
  afterLoss = scoreFunc(tree, X, y, parsimony=mult)
@@ -381,7 +388,7 @@ mutable struct PopMember
381
  tree::Node
382
  score::Float32
383
  birth::Int32
384
-
385
  PopMember(t) = new(t, scoreFunc(t, X, y, parsimony=parsimony), round(Int32, 1e3*(time()-1.6e9))
386
  )
387
  end
@@ -391,11 +398,11 @@ end
391
  mutable struct Population
392
  members::Array{PopMember, 1}
393
  n::Integer
394
-
395
  Population(pop::Array{PopMember, 1}) = new(pop, size(pop)[1])
396
  Population(npop::Integer) = new([PopMember(genRandomTree(3)) for i=1:npop], npop)
397
  Population(npop::Integer, nlength::Integer) = new([PopMember(genRandomTree(nlength)) for i=1:npop], npop)
398
-
399
  end
400
 
401
  # Sample 10 random members of the population, and make a new one
@@ -465,4 +472,3 @@ function run(
465
  end
466
  return pop
467
  end
468
-
 
44
  #Holds operators, variables, constants in a tree
45
  degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
46
  val::Union{Float32, Integer} #Either const value, or enumerates variable
47
+ constant::Bool #false if variable
48
  op::Function #enumerates operator (for degree=1,2)
49
  l::Union{Node, Nothing}
50
  r::Union{Node, Nothing}
51
+
52
  Node(val::Float32) = new(0, val, true, id, nothing, nothing)
53
  Node(val::Integer) = new(0, val, false, id, nothing, nothing)
54
  Node(op, l::Node) = new(1, 0.0f0, false, op, l, nothing)
55
  Node(op, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
56
  Node(op, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
57
+
58
  #Allow to pass the leaf value without additional node call:
59
  Node(op, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
60
  Node(op, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
 
121
  if tree.degree == 2
122
  c = countNodes(tree.r)
123
  end
124
+
125
  i = rand(1:1+b+c)
126
  if i <= b
127
  return randomNode(tree.l)
128
  elseif i == b + 1
129
  return tree
130
  end
131
+
132
+ return randomNode(tree.r)
133
  end
134
 
135
  # Count the number of unary operators in the equation
 
193
  tree::Node, T::Float32,
194
  probNegate::Float32=0.01f0)::Node
195
  # T is between 0 and 1.
196
+
197
  if countConstants(tree) == 0
198
  return tree
199
  end
 
201
  while node.degree != 0 || node.constant == false
202
  node = randomNode(tree)
203
  end
204
+
205
  bottom = 0.1f0
206
  maxChange = T + 1.0f0 + bottom
207
  factor = maxChange^Float32(rand())
208
  makeConstBigger = rand() > 0.5
209
+
210
+ if makeConstBigger
211
  node.val *= factor
212
  else
213
  node.val /= factor
214
  end
215
+
216
  if rand() > probNegate
217
  node.val *= -1
218
  end
219
+
220
  return tree
221
  end
222
 
 
224
  function evalTreeArray(
225
  tree::Node,
226
  x::Array{Float32, 2})::Array{Float32, 1}
227
+ len = ones(size(x))[2]
228
+ if tree.degree == 0
229
+ if tree.constant
230
+ return ones(Float32, len) .* tree.val
231
+ else
232
+ return ones(Float32, len) .* x[:, tree.val]
233
+ end
234
+ elseif tree.degree == 1
235
+ return tree.op.(evalTree(tree.l, x))
236
+ else
237
+ return tree.op.(evalTree(tree.l, x), evalTree(tree.r, x))
238
+ end
239
  end
240
 
241
  # Sum of square error between two arrays
 
267
  while node.degree != 0
268
  node = randomNode(tree)
269
  end
270
+
271
  choice = rand()
272
  makeNewBinOp = choice < nbin/nops
273
  if rand() > 0.5
 
280
  else
281
  right = rand(1:nvar)
282
  end
283
+
284
  if makeNewBinOp
285
  newnode = Node(
286
  binops[rand(1:length(binops))],
 
331
  mult::Float32=0.1f0
332
  )::Node
333
  prev = deepcopy(tree)
334
+
335
  mutationChoice = rand()
336
  weight_for_constant = min(8, countConstants(tree))
337
  weights = [weight_for_constant, 1, 1, 1, 2]
338
  weights /= sum(weights)
339
  cweights = cumsum(weights)
340
  n = countNodes(tree)
341
+
342
  if mutationChoice < cweights[1]
343
  tree = mutateConstant(tree, T)
344
  elseif mutationChoice < cweights[2]
 
350
  else
351
  tree = tree
352
  end
353
+
354
  try
355
  beforeLoss = scoreFunc(prev, X, y, parsimony=mult)
356
  afterLoss = scoreFunc(tree, X, y, parsimony=mult)
 
388
  tree::Node
389
  score::Float32
390
  birth::Int32
391
+
392
  PopMember(t) = new(t, scoreFunc(t, X, y, parsimony=parsimony), round(Int32, 1e3*(time()-1.6e9))
393
  )
394
  end
 
398
  mutable struct Population
399
  members::Array{PopMember, 1}
400
  n::Integer
401
+
402
  Population(pop::Array{PopMember, 1}) = new(pop, size(pop)[1])
403
  Population(npop::Integer) = new([PopMember(genRandomTree(3)) for i=1:npop], npop)
404
  Population(npop::Integer, nlength::Integer) = new([PopMember(genRandomTree(nlength)) for i=1:npop], npop)
405
+
406
  end
407
 
408
  # Sample 10 random members of the population, and make a new one
 
472
  end
473
  return pop
474
  end