File size: 2,306 Bytes
ce60798
4037c2d
32b2f6e
6baa534
976f8d8
32b2f6e
 
 
 
 
4037c2d
d325c42
4037c2d
 
 
 
 
 
32b2f6e
4037c2d
68ea1be
a4bb529
e530637
 
d72c643
44ff874
4037c2d
3856951
d3026af
4300bea
d42f10b
d3026af
6baa534
bcaffe2
 
68ea1be
6baa534
 
 
80fecb9
6baa534
 
 
4bc0a76
 
68ea1be
 
 
 
4ee8cdb
68ea1be
 
4ee8cdb
 
 
 
68ea1be
 
 
a4bb529
 
 
 
 
 
d72c643
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Functions for initializing the Julia environment and installing deps."""
import os
import sys
import warnings

if "juliacall" in sys.modules:
    # The environment variables below are meant to be set before juliacall is
    # imported; if it already is, setting them here may come too late, so tell
    # the user to configure them themselves.
    warnings.warn(
        "juliacall module already imported. Make sure that you have set `PYTHON_JULIACALL_HANDLE_SIGNALS=yes` to avoid segfaults."
    )

# Required to avoid segfaults (https://juliapy.github.io/PythonCall.jl/dev/faq/)
# An unset or empty value is accepted silently (as the warning text says): it is
# overwritten with "yes" just below. Only an explicit, different value warns.
if os.environ.get("PYTHON_JULIACALL_HANDLE_SIGNALS", "") not in ("yes", ""):
    warnings.warn(
        "PYTHON_JULIACALL_HANDLE_SIGNALS environment variable is set to something other than 'yes' or ''. "
        + "You will experience segfaults if running with multithreading."
    )

os.environ["PYTHON_JULIACALL_HANDLE_SIGNALS"] = "yes"
# Default Julia's thread count to "auto" unless the user already chose one.
os.environ.setdefault("JULIA_NUM_THREADS", "auto")

import juliapkg
from juliacall import Main as jl
from juliacall import convert as jl_convert

# Bring the Julia modules used by the helpers below into Julia's Main scope,
# each bound under its own name.
for _using_statement in (
    "using Serialization: Serialization",
    "using PythonCall: PythonCall",
):
    jl.seval(_using_statement)
del _using_statement

# Module-level Julia session state. These are only defaults here; presumably
# they are updated by initialization code elsewhere in the package (confirm).
julia_initialized = False
juliainfo = julia_kwargs_at_initialization = julia_activated_env = None


def _get_io_arg(quiet):
    io = "devnull" if quiet else "stderr"
    return f"io={io}"


def _escape_filename(filename):
    """Turn a path into a string with correctly escaped backslashes."""
    str_repr = str(filename)
    str_repr = str_repr.replace("\\", "\\\\")
    return str_repr


def _backend_version_assertion():
    """Warn if the loaded SymbolicRegression.jl version differs from the expected one.

    The loaded version is read from Julia (``SymbolicRegression.PACKAGE_VERSION``).
    The expected version comes from ``juliapkg.status`` for the
    ``SymbolicRegression`` target. A mismatch only emits a warning; nothing is
    raised.
    """
    loaded_version = jl.seval("string(SymbolicRegression.PACKAGE_VERSION)")
    pinned_version = juliapkg.status(target="SymbolicRegression").version
    if loaded_version != pinned_version:  # pragma: no cover
        message = (
            f"PySR backend (SymbolicRegression.jl) version {loaded_version} "
            f"does not match expected version {pinned_version}. "
            "Things may break. "
        )
        warnings.warn(message)


def _load_cluster_manager(cluster_manager):
    """Import ``addprocs_<cluster_manager>`` from ClusterManagers.jl and return it.

    ``cluster_manager`` is interpolated directly into Julia source, so it must be
    a trusted identifier suffix.
    """
    addprocs_name = f"addprocs_{cluster_manager}"
    jl.seval(f"using ClusterManagers: {addprocs_name}")
    return jl.seval(addprocs_name)


def jl_array(x):
    """Convert *x* to a Julia ``Array`` with ``juliacall.convert``.

    ``None`` is passed through unchanged.
    """
    return None if x is None else jl_convert(jl.Array, x)


def jl_deserialize_s(s):
    """Deserialize a Julia object from serialized data.

    ``s`` is converted to a Julia array and written into an in-memory
    ``IOBuffer``. The buffer is rewound and read back with
    ``Serialization.deserialize``. ``None`` is returned unchanged.

    NOTE(review): presumably ``s`` holds bytes produced by Julia's
    ``Serialization.serialize``; like pickle, this should only be fed trusted
    data — confirm the callers' data sources.
    """
    if s is None:
        return None
    stream = jl.IOBuffer()
    jl.write(stream, jl_array(s))
    # Rewind so deserialization reads from the start of what was just written.
    jl.seekstart(stream)
    return jl.Serialization.deserialize(stream)