Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
ea010a7
1
Parent(s):
34fadcf
Catch domain errors during classical optimization
Browse files- julia/sr.jl +22 -13
- pysr/sr.py +1 -1
julia/sr.jl
CHANGED
@@ -689,21 +689,30 @@ function optimizeConstants(member::PopMember)::PopMember
|
|
689 |
algorithm = Optim.NelderMead
|
690 |
end
|
691 |
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
|
|
|
|
698 |
end
|
699 |
-
end
|
700 |
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
|
706 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
707 |
end
|
708 |
return member
|
709 |
end
|
|
|
689 |
algorithm = Optim.NelderMead
|
690 |
end
|
691 |
|
692 |
+
try
|
693 |
+
result = Optim.optimize(f, x0, algorithm(), Optim.Options(iterations=100))
|
694 |
+
# Try other initial conditions:
|
695 |
+
for i=1:nrestarts
|
696 |
+
tmpresult = Optim.optimize(f, x0 .* (1f0 .+ 5f-1*randn(Float32, size(x0)[1])), algorithm(), Optim.Options(iterations=100))
|
697 |
+
if tmpresult.minimum < result.minimum
|
698 |
+
result = tmpresult
|
699 |
+
end
|
700 |
end
|
|
|
701 |
|
702 |
+
if Optim.converged(result)
|
703 |
+
setConstants(member.tree, result.minimizer)
|
704 |
+
member.score = convert(Float32, result.minimum)
|
705 |
+
member.birth = getTime()
|
706 |
+
else
|
707 |
+
setConstants(member.tree, x0)
|
708 |
+
end
|
709 |
+
catch error
|
710 |
+
# Fine if optimization encountered domain error, just return x0
|
711 |
+
if isa(error, AssertionError)
|
712 |
+
setConstants(member.tree, x0)
|
713 |
+
else
|
714 |
+
throw(error)
|
715 |
+
end
|
716 |
end
|
717 |
return member
|
718 |
end
|
pysr/sr.py
CHANGED
@@ -113,7 +113,7 @@ def pysr(X=None, y=None, threads=4,
|
|
113 |
y = eval(eval_str)
|
114 |
print("Running on", eval_str)
|
115 |
|
116 |
-
pkg_directory = '/'.join(__file__.split('/')[:-2] + ['
|
117 |
|
118 |
def_hyperparams = f"""include("{pkg_directory}/operators.jl")
|
119 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|
|
|
113 |
y = eval(eval_str)
|
114 |
print("Running on", eval_str)
|
115 |
|
116 |
+
pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
|
117 |
|
118 |
def_hyperparams = f"""include("{pkg_directory}/operators.jl")
|
119 |
const binops = {'[' + ', '.join(binary_operators) + ']'}
|