MilesCranmer commited on
Commit
d3b42d5
1 Parent(s): 8d9fde2

Change default hyperparams; add skip option for mutation

Browse files
Files changed (3) hide show
  1. README.md +2 -2
  2. eureqa.jl +12 -7
  3. paralleleureqa.jl +4 -4
README.md CHANGED
@@ -20,8 +20,8 @@ You can change the dataset here:
20
  ```
21
  const nvar = 5;
22
  const X = rand(100, nvar);
23
- # Here is the function we want to learn (x2^2 + cos(x3) + 5)
24
- const y = ((cx,)->cx^2).(X[:, 2]) + cos.(X[:, 3]) .+ 5.0;
25
  ```
26
  by either loading in a dataset, or modifying the definition of `y`.
27
 
 
20
  ```
21
  const nvar = 5;
22
  const X = rand(100, nvar);
23
+ # Here is the function we want to learn (x2^2 + cos(x3))
24
+ const y = ((cx,)->cx^2).(X[:, 2]) + cos.(X[:, 3])
25
  ```
26
  by either loading in a dataset, or modifying the definition of `y`.
27
 
eureqa.jl CHANGED
@@ -9,21 +9,24 @@ const binops = [plus, mult]
9
  const unaops = [sin, cos, exp]
10
  ##########################
11
 
 
 
 
12
  const nvar = 5;
13
  # Here is the function we want to learn (x2^2 + cos(x3) + 5)
14
  #
15
  ##########################
16
  # # Dataset to learn
17
- const X = rand(100, nvar)
18
- const y = ((cx,)->cx^2).(X[:, 2]) + cos.(X[:, 3]) .+ 5.0;
19
  ##########################
20
 
21
  ##################
22
  # Hyperparameters
23
  # How much to punish complexity
24
- const parsimony = 0.01
25
  # How much to scale temperature by (T between 0 and 1)
26
- const alpha = 10.0
27
  const maxsize = 20
28
  ##################
29
 
@@ -315,7 +318,7 @@ function iterate(
315
  prev = deepcopy(tree)
316
 
317
  mutationChoice = rand()
318
- weights = [8, 1, 1, 1]
319
  weights /= sum(weights)
320
  cweights = cumsum(weights)
321
  n = countNodes(tree)
@@ -328,6 +331,8 @@ function iterate(
328
  tree = appendRandomOp(tree)
329
  elseif mutationChoice < cweights[4]
330
  tree = deleteRandomOp(tree)
 
 
331
  end
332
 
333
  try
@@ -384,7 +389,7 @@ end
384
 
385
  # Sample 10 random members of the population, and make a new one
386
  function samplePop(pop::Population)::Population
387
- idx = rand(1:pop.n, 10)
388
  return Population(pop.members[idx])#Population(deepcopy(pop.members[idx]))
389
  end
390
 
@@ -414,7 +419,7 @@ end
414
  # Pass through the population several times, replacing the oldest
415
  # with the fittest of a small subsample
416
  function regEvolCycle(pop::Population, T::Float64)::Population
417
- for i=1:Int(pop.n/10)
418
  baby = iterateSample(pop, T)
419
  #printTree(baby.tree)
420
  oldest = argmin([pop.members[member].birth for member=1:pop.n])
 
9
  const unaops = [sin, cos, exp]
10
  ##########################
11
 
12
+ # How many equations to search when replacing
13
+ const ns=10;
14
+
15
  const nvar = 5;
16
  # Here is the function we want to learn (x2^2 + cos(x3) + 5)
17
  #
18
  ##########################
19
  # # Dataset to learn
20
+ const X = randn(100, nvar)*2
21
+ const y = ((cx,)->cx^2).(X[:, 2]) + cos.(X[:, 3])
22
  ##########################
23
 
24
  ##################
25
  # Hyperparameters
26
  # How much to punish complexity
27
+ const parsimony = 1e-3
28
  # How much to scale temperature by (T between 0 and 1)
29
+ const alpha = 100.0
30
  const maxsize = 20
31
  ##################
32
 
 
318
  prev = deepcopy(tree)
319
 
320
  mutationChoice = rand()
321
+ weights = [8, 1, 1, 1, 2]
322
  weights /= sum(weights)
323
  cweights = cumsum(weights)
324
  n = countNodes(tree)
 
331
  tree = appendRandomOp(tree)
332
  elseif mutationChoice < cweights[4]
333
  tree = deleteRandomOp(tree)
334
+ else
335
+ tree = tree
336
  end
337
 
338
  try
 
389
 
390
  # Sample 10 random members of the population, and make a new one
391
  function samplePop(pop::Population)::Population
392
+ idx = rand(1:pop.n, ns)
393
  return Population(pop.members[idx])#Population(deepcopy(pop.members[idx]))
394
  end
395
 
 
419
  # Pass through the population several times, replacing the oldest
420
  # with the fittest of a small subsample
421
  function regEvolCycle(pop::Population, T::Float64)::Population
422
+ for i=1:Int(pop.n/ns)
423
  baby = iterateSample(pop, T)
424
  #printTree(baby.tree)
425
  oldest = argmin([pop.members[member].birth for member=1:pop.n])
paralleleureqa.jl CHANGED
@@ -1,12 +1,12 @@
1
  include("eureqa.jl")
2
 
3
- println("Lets try to learn (x2^2 + cos(x3) + 5) using regularized evolution from scratch")
4
  const nthreads = Threads.nthreads()
5
  println("Running with $nthreads threads")
6
- const npop = 100
7
  const annealing = true
8
- const niterations = 30
9
- const ncyclesperiteration = 10000
10
 
11
  # Generate random initial populations
12
  allPops = [Population(npop, 3) for j=1:nthreads]
 
1
  include("eureqa.jl")
2
 
3
+ println("Lets try to learn (x2^2 + cos(x3)) using regularized evolution from scratch")
4
  const nthreads = Threads.nthreads()
5
  println("Running with $nthreads threads")
6
+ const npop = 1000
7
  const annealing = true
8
+ const niterations = 100
9
+ const ncyclesperiteration = 30000
10
 
11
  # Generate random initial populations
12
  allPops = [Population(npop, 3) for j=1:nthreads]