MilesCranmer commited on
Commit
ea010a7
1 Parent(s): 34fadcf

Catch domain errors during classical optimization

Browse files
Files changed (2) hide show
  1. julia/sr.jl +22 -13
  2. pysr/sr.py +1 -1
julia/sr.jl CHANGED
@@ -689,21 +689,30 @@ function optimizeConstants(member::PopMember)::PopMember
689
  algorithm = Optim.NelderMead
690
  end
691
 
692
- result = Optim.optimize(f, x0, algorithm(), Optim.Options(iterations=100))
693
- # Try other initial conditions:
694
- for i=1:nrestarts
695
- tmpresult = Optim.optimize(f, x0 .* (1f0 .+ 5f-1*randn(Float32, size(x0)[1])), algorithm(), Optim.Options(iterations=100))
696
- if tmpresult.minimum < result.minimum
697
- result = tmpresult
 
 
698
  end
699
- end
700
 
701
- if Optim.converged(result)
702
- setConstants(member.tree, result.minimizer)
703
- member.score = convert(Float32, result.minimum)
704
- member.birth = getTime()
705
- else
706
- setConstants(member.tree, x0)
 
 
 
 
 
 
 
 
707
  end
708
  return member
709
  end
 
689
  algorithm = Optim.NelderMead
690
  end
691
 
692
+ try
693
+ result = Optim.optimize(f, x0, algorithm(), Optim.Options(iterations=100))
694
+ # Try other initial conditions:
695
+ for i=1:nrestarts
696
+ tmpresult = Optim.optimize(f, x0 .* (1f0 .+ 5f-1*randn(Float32, size(x0)[1])), algorithm(), Optim.Options(iterations=100))
697
+ if tmpresult.minimum < result.minimum
698
+ result = tmpresult
699
+ end
700
  end
 
701
 
702
+ if Optim.converged(result)
703
+ setConstants(member.tree, result.minimizer)
704
+ member.score = convert(Float32, result.minimum)
705
+ member.birth = getTime()
706
+ else
707
+ setConstants(member.tree, x0)
708
+ end
709
+ catch error
710
+ # Fine if optimization encountered domain error, just return x0
711
+ if isa(error, AssertionError)
712
+ setConstants(member.tree, x0)
713
+ else
714
+ throw(error)
715
+ end
716
  end
717
  return member
718
  end
pysr/sr.py CHANGED
@@ -113,7 +113,7 @@ def pysr(X=None, y=None, threads=4,
113
  y = eval(eval_str)
114
  print("Running on", eval_str)
115
 
116
- pkg_directory = '/'.join(__file__.split('/')[:-2] + ['/julia'])
117
 
118
  def_hyperparams = f"""include("{pkg_directory}/operators.jl")
119
  const binops = {'[' + ', '.join(binary_operators) + ']'}
 
113
  y = eval(eval_str)
114
  print("Running on", eval_str)
115
 
116
+ pkg_directory = '/'.join(__file__.split('/')[:-2] + ['julia'])
117
 
118
  def_hyperparams = f"""include("{pkg_directory}/operators.jl")
119
  const binops = {'[' + ', '.join(binary_operators) + ']'}