PySR / julia /Equation.jl
MilesCranmer's picture
Make filenaming more consistent
48d465b
raw
history blame
5.1 kB
# Define a serialization format for the symbolic equations:
mutable struct Node
#Holds operators, variables, constants in a tree
degree::Integer #0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
val::Union{Float32, Integer} #Either const value, or enumerates variable
constant::Bool #false if variable
op::Integer #enumerates operator (separately for degree=1,2)
l::Union{Node, Nothing}
r::Union{Node, Nothing}
Node(val::Float32) = new(0, val, true, 1, nothing, nothing)
Node(val::Integer) = new(0, val, false, 1, nothing, nothing)
Node(op::Integer, l::Node) = new(1, 0.0f0, false, op, l, nothing)
Node(op::Integer, l::Union{Float32, Integer}) = new(1, 0.0f0, false, op, Node(l), nothing)
Node(op::Integer, l::Node, r::Node) = new(2, 0.0f0, false, op, l, r)
#Allow to pass the leaf value without additional node call:
Node(op::Integer, l::Union{Float32, Integer}, r::Node) = new(2, 0.0f0, false, op, Node(l), r)
Node(op::Integer, l::Node, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, l, Node(r))
Node(op::Integer, l::Union{Float32, Integer}, r::Union{Float32, Integer}) = new(2, 0.0f0, false, op, Node(l), Node(r))
end
# Copy an equation (faster than deepcopy)
function copyNode(tree::Node)::Node
if tree.degree == 0
return Node(tree.val)
elseif tree.degree == 1
return Node(tree.op, copyNode(tree.l))
else
return Node(tree.op, copyNode(tree.l), copyNode(tree.r))
end
end
# Count the operators, constants, variables in an equation
function countNodes(tree::Node)::Integer
if tree.degree == 0
return 1
elseif tree.degree == 1
return 1 + countNodes(tree.l)
else
return 1 + countNodes(tree.l) + countNodes(tree.r)
end
end
# Count the max depth of a tree
function countDepth(tree::Node)::Integer
if tree.degree == 0
return 1
elseif tree.degree == 1
return 1 + countDepth(tree.l)
else
return 1 + max(countDepth(tree.l), countDepth(tree.r))
end
end
# Convert an equation to a string
function stringTree(tree::Node)::String
if tree.degree == 0
if tree.constant
return string(tree.val)
else
if useVarMap
return varMap[tree.val]
else
return "x$(tree.val - 1)"
end
end
elseif tree.degree == 1
return "$(unaops[tree.op])($(stringTree(tree.l)))"
else
return "$(binops[tree.op])($(stringTree(tree.l)), $(stringTree(tree.r)))"
end
end
# Print an equation
function printTree(tree::Node)
println(stringTree(tree))
end
# Return a random node from the tree
function randomNode(tree::Node)::Node
if tree.degree == 0
return tree
end
a = countNodes(tree)
b = 0
c = 0
if tree.degree >= 1
b = countNodes(tree.l)
end
if tree.degree == 2
c = countNodes(tree.r)
end
i = rand(1:1+b+c)
if i <= b
return randomNode(tree.l)
elseif i == b + 1
return tree
end
return randomNode(tree.r)
end
# Count the number of unary operators in the equation
function countUnaryOperators(tree::Node)::Integer
if tree.degree == 0
return 0
elseif tree.degree == 1
return 1 + countUnaryOperators(tree.l)
else
return 0 + countUnaryOperators(tree.l) + countUnaryOperators(tree.r)
end
end
# Count the number of binary operators in the equation
function countBinaryOperators(tree::Node)::Integer
if tree.degree == 0
return 0
elseif tree.degree == 1
return 0 + countBinaryOperators(tree.l)
else
return 1 + countBinaryOperators(tree.l) + countBinaryOperators(tree.r)
end
end
# Count the number of operators in the equation
function countOperators(tree::Node)::Integer
return countUnaryOperators(tree) + countBinaryOperators(tree)
end
# Count the number of constants in an equation
function countConstants(tree::Node)::Integer
if tree.degree == 0
return convert(Integer, tree.constant)
elseif tree.degree == 1
return 0 + countConstants(tree.l)
else
return 0 + countConstants(tree.l) + countConstants(tree.r)
end
end
# Get all the constants from a tree
function getConstants(tree::Node)::Array{Float32, 1}
if tree.degree == 0
if tree.constant
return [tree.val]
else
return Float32[]
end
elseif tree.degree == 1
return getConstants(tree.l)
else
both = [getConstants(tree.l), getConstants(tree.r)]
return [constant for subtree in both for constant in subtree]
end
end
# Set all the constants inside a tree
function setConstants(tree::Node, constants::Array{Float32, 1})
if tree.degree == 0
if tree.constant
tree.val = constants[1]
end
elseif tree.degree == 1
setConstants(tree.l, constants)
else
numberLeft = countConstants(tree.l)
setConstants(tree.l, constants)
setConstants(tree.r, constants[numberLeft+1:end])
end
end