MilesCranmer commited on
Commit
53698db
1 Parent(s): 52a6b5b

Add warning message in case a user import torch

Browse files
Files changed (1) hide show
  1. pysr/julia_helpers.py +17 -0
pysr/julia_helpers.py CHANGED
@@ -1,4 +1,5 @@
1
  """Functions for initializing the Julia environment and installing deps."""
 
2
  import warnings
3
  from pathlib import Path
4
  import os
@@ -83,8 +84,24 @@ def is_julia_version_greater_eq(Main, version="1.6"):
83
  return Main.eval(f'VERSION >= v"{version}"')
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def init_julia(julia_project=None):
87
  """Initialize julia binary, turning off compiled modules if needed."""
 
88
  from julia.core import JuliaInfo, UnsupportedPythonError
89
 
90
  julia_project, is_shared = _get_julia_project(julia_project)
 
1
  """Functions for initializing the Julia environment and installing deps."""
2
+ import sys
3
  import warnings
4
  from pathlib import Path
5
  import os
 
84
  return Main.eval(f'VERSION >= v"{version}"')
85
 
86
 
87
+ def check_for_conflicting_libraries(): # pragma: no cover
88
+ """Check whether there are conflicting modules, and display warnings."""
89
+ # See https://github.com/pytorch/pytorch/issues/78829: importing
90
+ # pytorch before running `pysr.fit` causes a segfault.
91
+ torch_is_loaded = "torch" in sys.modules
92
+ if torch_is_loaded:
93
+ warnings.warn(
94
+ "`torch` was loaded before the Julia instance started. "
95
+ "This may cause a segfault when running `PySRRegressor.fit`. "
96
+ "To avoid this, please run `pysr.julia_helpers.init_julia()` *before* "
97
+ "importing `torch`. "
98
+ "For updates, see https://github.com/pytorch/pytorch/issues/78829"
99
+ )
100
+
101
+
102
  def init_julia(julia_project=None):
103
  """Initialize julia binary, turning off compiled modules if needed."""
104
+ check_for_conflicting_libraries()
105
  from julia.core import JuliaInfo, UnsupportedPythonError
106
 
107
  julia_project, is_shared = _get_julia_project(julia_project)