Dhananjay Ashok commited on
Commit
d7fdb9c
Β·
unverified Β·
2 Parent(s): 371f4fe e30b74d

Merge branch 'recover' into refactoring

Browse files
.gitignore CHANGED
@@ -8,8 +8,7 @@ trials*
8
  **/__pycache__
9
  build
10
  dist
11
- *.idea/
12
  *.vs/*
13
  *.pyproj
14
  *.sln
15
- pysr/.vs/
 
8
  **/__pycache__
9
  build
10
  dist
 
11
  *.vs/*
12
  *.pyproj
13
  *.sln
14
+ pysr/.vs/
README.md CHANGED
@@ -65,16 +65,11 @@ pip install pysr
65
  ```
66
 
67
  # Quickstart
68
- For a simple demo, start by navigating to the root directory and then in a terminal call
69
- ```
70
- >> python main.py
71
- ```
72
- <br>
73
- For a more hands on demo you can do the following (also in the root directory)
74
 
 
75
  ```python
76
  import numpy as np
77
- from pysr import pysr, best, get_hof
78
 
79
  # Dataset
80
  X = 2*np.random.randn(100, 5)
 
65
  ```
66
 
67
  # Quickstart
 
 
 
 
 
 
68
 
69
+ Here is some demo code (also found in `example.py`)
70
  ```python
71
  import numpy as np
72
+ from pysr import pysr, best
73
 
74
  # Dataset
75
  X = 2*np.random.randn(100, 5)
main.py β†’ example.py RENAMED
@@ -1,19 +1,17 @@
1
  import numpy as np
2
- from pysr import pysr, best, get_hof
3
- import time
4
 
5
  # Dataset
6
  X = 2*np.random.randn(100, 5)
7
  y = 2*np.cos(X[:, 3]) + X[:, 0]**2 - 2
8
 
9
-
10
  # Learn equations
11
- start = time.time()
12
  equations = pysr(X, y, niterations=5,
13
- binary_operators=["plus", "mult"],
14
- unary_operators=["cos", "exp", "sin"])
 
 
15
 
16
- ... # (you can use ctl-c to exit early)
17
 
18
  print(best(equations))
19
- print(f"Took {time.time()-start} seconds")
 
1
  import numpy as np
2
+ from pysr import pysr, best
 
3
 
4
  # Dataset
5
  X = 2*np.random.randn(100, 5)
6
  y = 2*np.cos(X[:, 3]) + X[:, 0]**2 - 2
7
 
 
8
  # Learn equations
 
9
  equations = pysr(X, y, niterations=5,
10
+ binary_operators=["plus", "mult"],
11
+ unary_operators=[
12
+ "cos", "exp", "sin", #Pre-defined library of operators (see https://pysr.readthedocs.io/en/latest/docs/operators/)
13
+ "inv(x) = 1/x"]) # Define your own operator! (Julia syntax)
14
 
15
+ ...# (you can use ctl-c to exit early)
16
 
17
  print(best(equations))
 
julia/{complexityChecks.jl β†’ CheckConstraints.jl} RENAMED
@@ -39,4 +39,4 @@ function flagUnaOperatorComplexity(tree::Node, op::Int)::Bool
39
  else
40
  return (flagUnaOperatorComplexity(tree.l, op) || flagUnaOperatorComplexity(tree.r, op))
41
  end
42
- end
 
39
  else
40
  return (flagUnaOperatorComplexity(tree.l, op) || flagUnaOperatorComplexity(tree.r, op))
41
  end
42
+ end
julia/{optimization.jl β†’ ConstantOptimization.jl} RENAMED
@@ -1,3 +1,5 @@
 
 
1
  # Proxy function for optimization
2
  function optFunc(x::Array{Float32, 1}, tree::Node)::Float32
3
  setConstants(tree, x)
@@ -44,4 +46,4 @@ function optimizeConstants(member::PopMember)::PopMember
44
  end
45
  end
46
  return member
47
- end
 
1
+ import Optim
2
+
3
  # Proxy function for optimization
4
  function optFunc(x::Array{Float32, 1}, tree::Node)::Float32
5
  setConstants(tree, x)
 
46
  end
47
  end
48
  return member
49
+ end
julia/{Node.jl β†’ Equation.jl} RENAMED
File without changes
julia/{eval.jl β†’ EvaluateEquation.jl} RENAMED
@@ -45,38 +45,3 @@ function evalTreeArray(tree::Node, cX::Array{Float32, 2})::Union{Array{Float32,
45
  return cumulator
46
  end
47
  end
48
-
49
- # Score an equation
50
- function scoreFunc(tree::Node)::Float32
51
- prediction = evalTreeArray(tree)
52
- if prediction === nothing
53
- return 1f9
54
- end
55
- if weighted
56
- mse = MSE(prediction, y, weights)
57
- else
58
- mse = MSE(prediction, y)
59
- end
60
- return mse / baselineMSE + countNodes(tree)*parsimony
61
- end
62
-
63
- # Score an equation with a small batch
64
- function scoreFuncBatch(tree::Node)::Float32
65
- # batchSize
66
- batch_idx = randperm(len)[1:batchSize]
67
- batch_X = X[batch_idx, :]
68
- prediction = evalTreeArray(tree, batch_X)
69
- if prediction === nothing
70
- return 1f9
71
- end
72
- size_adjustment = 1f0
73
- batch_y = y[batch_idx]
74
- if weighted
75
- batch_w = weights[batch_idx]
76
- mse = MSE(prediction, batch_y, batch_w)
77
- size_adjustment = 1f0 * len / batchSize
78
- else
79
- mse = MSE(prediction, batch_y)
80
- end
81
- return size_adjustment * mse / baselineMSE + countNodes(tree)*parsimony
82
- end
 
45
  return cumulator
46
  end
47
  end
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
julia/{errors.jl β†’ LossFunctions.jl} RENAMED
@@ -1,3 +1,5 @@
 
 
1
  # Sum of square error between two arrays
2
  function SSE(x::Array{Float32}, y::Array{Float32})::Float32
3
  diff = (x - y)
@@ -34,4 +36,47 @@ end
34
  # Mean of square error between two arrays
35
  function MSE(x::Array{Float32}, y::Array{Float32}, w::Array{Float32})::Float32
36
  return SSE(x, y, w)/sum(w)
37
- end
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import Random: randperm
2
+
3
  # Sum of square error between two arrays
4
  function SSE(x::Array{Float32}, y::Array{Float32})::Float32
5
  diff = (x - y)
 
36
  # Mean of square error between two arrays
37
  function MSE(x::Array{Float32}, y::Array{Float32}, w::Array{Float32})::Float32
38
  return SSE(x, y, w)/sum(w)
39
+ end
40
+
41
+ if weighted
42
+ const avgy = sum(y .* weights)/sum(weights)
43
+ const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy), weights)
44
+ else
45
+ const avgy = sum(y)/len
46
+ const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy))
47
+ end
48
+
49
+ # Score an equation
50
+ function scoreFunc(tree::Node)::Float32
51
+ prediction = evalTreeArray(tree)
52
+ if prediction === nothing
53
+ return 1f9
54
+ end
55
+ if weighted
56
+ mse = MSE(prediction, y, weights)
57
+ else
58
+ mse = MSE(prediction, y)
59
+ end
60
+ return mse / baselineMSE + countNodes(tree)*parsimony
61
+ end
62
+
63
+ # Score an equation with a small batch
64
+ function scoreFuncBatch(tree::Node)::Float32
65
+ # batchSize
66
+ batch_idx = randperm(len)[1:batchSize]
67
+ batch_X = X[batch_idx, :]
68
+ prediction = evalTreeArray(tree, batch_X)
69
+ if prediction === nothing
70
+ return 1f9
71
+ end
72
+ size_adjustment = 1f0
73
+ batch_y = y[batch_idx]
74
+ if weighted
75
+ batch_w = weights[batch_idx]
76
+ mse = MSE(prediction, batch_y, batch_w)
77
+ size_adjustment = 1f0 * len / batchSize
78
+ else
79
+ mse = MSE(prediction, batch_y)
80
+ end
81
+ return size_adjustment * mse / baselineMSE + countNodes(tree)*parsimony
82
+ end
julia/{simulatedAnnealing.jl β†’ Mutate.jl} RENAMED
@@ -1,5 +1,4 @@
1
- # Go through one simulated annealing mutation cycle
2
- # exp(-delta/T) defines probability of accepting a change
3
  function iterate(member::PopMember, T::Float32, curmaxsize::Integer, frequencyComplexity::Array{Float32, 1})::PopMember
4
  prev = member.tree
5
  tree = prev
@@ -122,4 +121,4 @@ function iterate(member::PopMember, T::Float32, curmaxsize::Integer, frequencyCo
122
  end
123
  end
124
  return PopMember(tree, afterLoss)
125
- end
 
1
+ # Go through one mutation cycle
 
2
  function iterate(member::PopMember, T::Float32, curmaxsize::Integer, frequencyComplexity::Array{Float32, 1})::PopMember
3
  prev = member.tree
4
  tree = prev
 
121
  end
122
  end
123
  return PopMember(tree, afterLoss)
124
+ end
julia/{randomMutations.jl β†’ MutationFunctions.jl} RENAMED
File without changes
julia/{operators.jl β†’ Operators.jl} RENAMED
File without changes
julia/{constants.jl β†’ ProgramConstants.jl} RENAMED
File without changes
julia/{regEvolCycle.jl β†’ RegularizedEvolution.jl} RENAMED
@@ -1,3 +1,5 @@
 
 
1
  # Pass through the population several times, replacing the oldest
2
  # with the fittest of a small subsample
3
  function regEvolCycle(pop::Population, T::Float32, curmaxsize::Integer,
@@ -41,4 +43,4 @@ function regEvolCycle(pop::Population, T::Float32, curmaxsize::Integer,
41
  end
42
 
43
  return pop
44
- end
 
1
+ import Random: shuffle!
2
+
3
  # Pass through the population several times, replacing the oldest
4
  # with the fittest of a small subsample
5
  function regEvolCycle(pop::Population, T::Float32, curmaxsize::Integer,
 
43
  end
44
 
45
  return pop
46
+ end
julia/{simplification.jl β†’ SimplifyEquation.jl} RENAMED
File without changes
julia/{run.jl β†’ SingleIteration.jl} RENAMED
File without changes
julia/{utils.jl β†’ Utils.jl} RENAMED
@@ -1,3 +1,5 @@
 
 
1
  function id(x::Float32)::Float32
2
  x
3
  end
@@ -29,4 +31,4 @@ function testConfiguration()
29
  @printf("\n\nYour configuration is invalid - one of your operators is not well-defined over the real line.\n\n\n")
30
  throw(error)
31
  end
32
- end
 
1
+ import Printf: @printf
2
+
3
  function id(x::Float32)::Float32
4
  x
5
  end
 
31
  @printf("\n\nYour configuration is invalid - one of your operators is not well-defined over the real line.\n\n\n")
32
  throw(error)
33
  end
34
+ end
julia/halloffame.jl CHANGED
@@ -5,4 +5,4 @@ mutable struct HallOfFame
5
 
6
  # Arranged by complexity - store one at each.
7
  HallOfFame() = new([PopMember(Node(1f0), 1f9) for i=1:actualMaxsize], [false for i=1:actualMaxsize])
8
- end
 
5
 
6
  # Arranged by complexity - store one at each.
7
  HallOfFame() = new([PopMember(Node(1f0), 1f9) for i=1:actualMaxsize], [false for i=1:actualMaxsize])
8
+ end
julia/sr.jl CHANGED
@@ -1,49 +1,4 @@
1
- import Optim
2
  import Printf: @printf
3
- import Random: shuffle!, randperm
4
-
5
-
6
- include("constants.jl")
7
-
8
- include("errors.jl")
9
-
10
- if weighted
11
- const avgy = sum(y .* weights)/sum(weights)
12
- const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy), weights)
13
- else
14
- const avgy = sum(y)/len
15
- const baselineMSE = MSE(y, convert(Array{Float32, 1}, ones(len) .* avgy))
16
- end
17
-
18
- include("utils.jl")
19
-
20
- include("Node.jl")
21
-
22
- include("eval.jl")
23
-
24
- include("randomMutations.jl")
25
-
26
- include("simplification.jl")
27
-
28
- include("PopMember.jl")
29
-
30
-
31
- include("halloffame.jl")
32
-
33
-
34
- include("complexityChecks.jl")
35
-
36
- include("simulatedAnnealing.jl")
37
-
38
- include("Population.jl")
39
-
40
- include("regEvolCycle.jl")
41
-
42
- include("run.jl")
43
-
44
- include("optimization.jl")
45
-
46
-
47
 
48
  function fullRun(niterations::Integer;
49
  npop::Integer=300,
 
 
1
  import Printf: @printf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  function fullRun(niterations::Integer;
4
  npop::Integer=300,
julia/truth.jl ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # *** Custom Functions
2
+ ##################################################################################################################################
3
+ # *** Will somewhere need to define a list TRUTHS of all valid auxliary truths
4
+ struct Transformation
5
+ type::Integer # 1 is symmetry, 2 is zero, 3 is equality
6
+ params::Array{Int32}
7
+ Transformation(type::Integer, params::Array{Int32}) = new(type, params)
8
+ Transformation(type::Integer, params::Array{Int64}) = new(type, params)
9
+
10
+ end
11
+ struct Truth
12
+ transformation::Transformation
13
+ weights::Array{Float32}
14
+ Truth(transformation::Transformation, weights::Array{Float32}) = new(transformation, weights)
15
+ Truth(type::Int64, params::Array{Int64}, weights::Array{Float32}) = new(Transformation(type, params), weights)
16
+ Truth(transformation::Transformation, weights::Array{Float64}) = new(transformation, weights)
17
+ Truth(type::Int64, params::Array{Int64}, weights::Array{Float64}) = new(Transformation(type, params), weights)
18
+ end
19
+ # Returns a linear combination when given X of shape nxd, y of shape nx1 is f(x) and w of shape d+2x1, result is shape nx1
20
+ function LinearPrediction(cX::Array{Float32}, cy::Array{Float32}, w::Array{Float32})::Array{Float32}
21
+ preds = 0
22
+ for i in 1:ndims(cX)
23
+ preds = preds .+ cX[:,i].*w[i]
24
+ end
25
+ preds = preds .+ cy.*w[ndims(cX)+1]
26
+ return preds .+ w[ndims(cX)+2]
27
+ end
28
+
29
+ # Returns a copy of the data with the two specified columns swapped
30
+ function swapColumns(cX::Array{Float32, 2}, a::Integer, b::Integer)::Array{Float32, 2}
31
+ X1 = copy(cX)
32
+ X1[:, a] = cX[:, b]
33
+ X1[:, b] = cX[:, a]
34
+ return X1
35
+ end
36
+
37
+ # Returns a copy of the data with the specified integers in the list set to value given
38
+ function setVal(cX::Array{Float32, 2}, a::Array{Int32, 1}, val::Float32)::Array{Float32, 2}
39
+ X1 = copy(cX)
40
+ for i in 1:size(a)[1]
41
+ X1[:, a[i]] = fill!(cX[:, a[i]], val)
42
+ end
43
+ return X1
44
+ end
45
+
46
+ # Returns a copy of the data with the specified integer indices in the list set to the first item of that list
47
+ function setEq(cX::Array{Float32, 2}, a::Array{Int32, 1})::Array{Float32, 2}
48
+ X1 = copy(cX)
49
+ val = X1[:, a[1]]
50
+ for i in 1:size(a)[1]
51
+ X1[:, a[i]] = val
52
+ end
53
+ return X1
54
+ end
55
+
56
+ # Takes in a dataset and returns the transformed version of it as per the specified type and parameters
57
+ function transform(cX::Array{Float32, 2}, transformation::Transformation)::Array{Float32, 2}
58
+ if transformation.type==1 # then symmetry
59
+ a = transformation.params[1]
60
+ b = transformation.params[2]
61
+ return swapColumns(cX, a, b)
62
+ elseif transformation.type==2 # then zero condition
63
+ return setVal(cX, transformation.params, Float32(0))
64
+ elseif transformation.type == 3 # then equality condition
65
+ return setEq(cX, transformation.params)
66
+ else # Then error return X
67
+ return cX
68
+ end
69
+ end
70
+ function transform(cX::Array{Float32, 2}, truth::Truth)::Array{Float32, 2}
71
+ return transform(cX, truth.transformation)
72
+ end
73
+
74
+ # Takes in X that has been transformed and returns what the Truth projects the target values should be
75
+ function truthPrediction(X_transformed::Array{Float32, 2}, cy::Array{Float32}, truth::Truth)::Array{Float32}
76
+ return LinearPrediction(X_transformed, cy, truth.weights)
77
+ end
julia/truthPops.jl ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Returns the MSE between the predictions and the truth provided targets for the given dataset
2
+ function truthScore(member::PopMember, cX::Array{Float32, 2}, cy::Array{Float32}, truth::Truth)::Float32
3
+ transformed = transform(cX, truth)
4
+ targets = truthPrediction(transformed, cy, truth)
5
+ preds = evalTreeArray(member.tree, transformed)
6
+ return MSE(preds, targets)
7
+ end
8
+
9
+ # Assumes a dataset X, y for a given truth
10
+ function truthScore(member::PopMember, truth::Truth)::Float32
11
+ return truthScore(member, X, y, truth)
12
+ end
13
+
14
+ # Assumes a list of Truths TRUTHS is defined. Performs the truthScore function for each of them and returns the average
15
+ function truthScore(member::PopMember, cX::Array{Float32, 2}, cy::Array{Float32})::Float32
16
+ s = 0
17
+ for truth in TRUTHS
18
+ s += (truthScore(member, cX, cy, truth))/size(TRUTHS)[1]
19
+ end
20
+ return s
21
+ end
22
+
23
+ # Assumes list of Truths TRUTHS and dataset X, y are defined
24
+ function truthScore(member::PopMember)::Float32
25
+ return truthScore(member, X, y)
26
+ end
27
+ # Returns the MSE between the predictions and the truth provided targets for the given dataset
28
+ function truthScore(tree::Node, cX::Array{Float32, 2}, cy::Array{Float32}, truth::Truth)::Float32
29
+ transformed = transform(cX, truth)
30
+ targets = truthPrediction(transformed, cy, truth)
31
+ preds = evalTreeArray(tree, transformed)
32
+ return MSE(preds, targets)
33
+ end
34
+
35
+ # Assumes a dataset X, y for a given truth
36
+ function truthScore(tree::Node, truth::Truth)::Float32
37
+ return truthScore(tree, X, y, truth)
38
+ end
39
+
40
+ # Assumes a list of Truths TRUTHS is defined. Performs the truthScore function for each of them and returns the average
41
+ function truthScore(tree::Node, cX::Array{Float32, 2}, cy::Array{Float32})::Float32
42
+ s = 0
43
+ for truth in TRUTHS
44
+ s += (truthScore(tree, cX, cy, truth))/size(TRUTHS)[1]
45
+ end
46
+ return s
47
+ end
48
+
49
+ # Assumes list of Truths TRUTHS and dataset X, y are defined
50
+ function truthScore(tree::Node)::Float32
51
+ return truthScore(tree, X, y)
52
+ end
53
+
54
+ # Returns true iff Truth Score is below a given threshold i.e truth is satisfied
55
+ function testTruth(member::PopMember, truth::Truth, threshold::Float32=Float32(1.0e-8))::Bool
56
+ truthError = truthScore(member, truth)
57
+ #print(stringTree(member.tree), "\n")
58
+ #print(truth, ": ")
59
+ #print(truthError, "\n")
60
+ if truthError > threshold
61
+ #print("Returns False \n ----\n")
62
+ return false
63
+ else
64
+ #print("Returns True \n ----\n")
65
+ return true
66
+ end
67
+ end
68
+
69
+ # Returns a list of violating functions from assumed list TRUTHS
70
+ function violatingTruths(member::PopMember)::Array{Truth}
71
+ return violatingTruths(member.tree)
72
+ end
73
+
74
+ # Returns true iff Truth Score is below a given threshold i.e truth is satisfied
75
+ function testTruth(tree::Node, truth::Truth, threshold::Float32=Float32(1.0e-3))::Bool
76
+ truthError = truthScore(tree, truth)
77
+ if truthError > threshold
78
+ return false
79
+ else
80
+ return true
81
+ end
82
+ end
83
+
84
+ # Returns a list of violating functions from assumed list TRUTHS
85
+ function violatingTruths(tree::Node)::Array{Truth}
86
+ toReturn = []
87
+ #print("\n Checking Equation ", stringTree(tree), "\n")
88
+ for truth in TRUTHS
89
+ test_truth = testTruth(tree, truth)
90
+ #print("Truth: ", truth, ": " , test_truth, "\n-----\n")
91
+ if !test_truth
92
+ append!(toReturn, [truth])
93
+ end
94
+ end
95
+ return toReturn
96
+ end
97
+
98
+ function randomIndex(cX::Array{Float32, 2}, k::Integer=10)::Array{Int32, 1}
99
+ indxs = sample([Int32(i) for i in 1:size(cX)[1]], k)
100
+ return indxs
101
+ end
102
+
103
+ function randomIndex(leng::Integer, k::Integer=10)::Array{Int32, 1}
104
+ indxs = sample([Int32(i) for i in 1:leng], k)
105
+ return indxs
106
+ end
107
+
108
+ function extendedX(cX::Array{Float32, 2}, truth::Truth, indx::Array{Int32, 1})::Array{Float32, 2}
109
+ workingcX = copy(cX)
110
+ X_slice = workingcX[indx, :]
111
+ X_transformed = transform(X_slice, truth)
112
+ return X_transformed
113
+ end
114
+ function extendedX(truth::Truth, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
115
+ return extendedX(OGX, truth, indx)
116
+ end
117
+ function extendedX(cX::Array{Float32, 2}, violatedTruths::Array{Truth}, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
118
+ if length(violatedTruths) == 0
119
+ return nothing
120
+ end
121
+ workingX = extendedX(cX, violatedTruths[1], indx)
122
+ for truth in violatedTruths[2:length(violatedTruths)]
123
+ workingX = vcat(workingX, extendedX(cX, truth, indx))
124
+ end
125
+ return workingX
126
+ end
127
+ function extendedX(violatedTruths::Array{Truth}, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
128
+ return extendedX(OGX, violatedTruths, indx)
129
+ end
130
+ function extendedX(tree::Node, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
131
+ violatedTruths = violatingTruths(tree)
132
+ return extendedX(violatedTruths, indx)
133
+ end
134
+ function extendedX(member::PopMember, indx::Array{Int32, 1})::Union{Array{Float32, 2}, Nothing}
135
+ return extendedX(member.tree, indx)
136
+ end
137
+
138
+
139
+ function extendedy(cX::Array{Float32, 2}, cy::Array{Float32}, truth::Truth, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
140
+ cy = copy(cy)
141
+ cX = copy(cX)
142
+ X_slice = cX[indx, :]
143
+ y_slice = cy[indx]
144
+ X_transformed = transform(X_slice, truth)
145
+ y_transformed = truthPrediction(X_transformed, y_slice, truth)
146
+ return y_transformed
147
+ end
148
+ function extendedy(truth::Truth, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
149
+ return extendedy(OGX, OGy, truth, indx)
150
+ end
151
+ function extendedy(cX::Array{Float32, 2}, cy::Array{Float32}, violatedTruths::Array{Truth}, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
152
+ if length(violatedTruths) == 0
153
+ return nothing
154
+ end
155
+ workingy = extendedy(cX, cy, violatedTruths[1], indx)
156
+ for truth in violatedTruths[2:length(violatedTruths)]
157
+ workingy = vcat(workingy, extendedy(cX, cy, truth, indx))
158
+ end
159
+ return workingy
160
+ end
161
+ function extendedy(violatedTruths::Array{Truth}, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
162
+ return extendedy(OGX,OGy, violatedTruths, indx)
163
+ end
164
+ function extendedy(tree::Node, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
165
+ violatedTruths = violatingTruths(tree)
166
+ return extendedy(violatedTruths, indx)
167
+ end
168
+ function extendedy(member::PopMember, indx::Array{Int32, 1})::Union{Array{Float32}, Nothing}
169
+ return extendedy(member.tree, indx)
170
+ end
pysr/sr.py CHANGED
@@ -227,11 +227,13 @@ def pysr(X=None, y=None, weights=None,
227
  # Add pre-defined functions to Julia
228
  def_hyperparams = predefined_function_addition(binary_operators, def_hyperparams, unary_operators)
229
 
 
230
  #arbitrary complexity by default
231
  handle_constraints(binary_operators, constraints, unary_operators)
232
 
233
  constraints_str = make_constraints_str(binary_operators, constraints, unary_operators)
234
 
 
235
  def_hyperparams = make_hyperparams_julia_str(X, alpha, annealing, batchSize, batching, binary_operators,
236
  constraints_str, def_hyperparams, equation_file, fast_cycle,
237
  fractionReplacedHof, hofMigration, limitPowComplexity, maxdepth,
 
227
  # Add pre-defined functions to Julia
228
  def_hyperparams = predefined_function_addition(binary_operators, def_hyperparams, unary_operators)
229
 
230
+
231
  #arbitrary complexity by default
232
  handle_constraints(binary_operators, constraints, unary_operators)
233
 
234
  constraints_str = make_constraints_str(binary_operators, constraints, unary_operators)
235
 
236
+
237
  def_hyperparams = make_hyperparams_julia_str(X, alpha, annealing, batchSize, batching, binary_operators,
238
  constraints_str, def_hyperparams, equation_file, fast_cycle,
239
  fractionReplacedHof, hofMigration, limitPowComplexity, maxdepth,