File size: 2,028 Bytes
ce60798
4037c2d
32b2f6e
6baa534
976f8d8
32b2f6e
 
 
 
 
4037c2d
d325c42
4037c2d
 
 
 
 
4a5693e
 
 
 
 
 
 
f335ea1
 
 
 
 
 
 
 
4037c2d
7ea0b2f
 
e530637
d72c643
44ff874
4037c2d
6baa534
 
80fecb9
6baa534
 
 
4bc0a76
 
68ea1be
 
 
a4bb529
 
 
 
 
 
d72c643
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Functions for initializing the Julia environment and installing deps."""
import os
import sys
import warnings

# juliacall was imported before this module ran, so remind the user about the
# signal-handling variable that has to be in place to avoid segfaults.
if "juliacall" in sys.modules:
    warnings.warn(
        "juliacall module already imported. "
        "Make sure that you have set `PYTHON_JULIACALL_HANDLE_SIGNALS=yes` "
        "to avoid segfaults."
    )

# Required to avoid segfaults (https://juliapy.github.io/PythonCall.jl/dev/faq/)
# An unset variable is fine: it receives the value "yes" in the loop below.
if os.environ.get("PYTHON_JULIACALL_HANDLE_SIGNALS") not in (None, "yes"):
    warnings.warn(
        "PYTHON_JULIACALL_HANDLE_SIGNALS environment variable is set to "
        "something other than 'yes' or ''. "
        "You will experience segfaults if running with multithreading."
    )

# A user-supplied thread count is kept as-is; only warn that it is not "auto".
if os.environ.get("JULIA_NUM_THREADS") not in (None, "auto"):
    warnings.warn(
        "JULIA_NUM_THREADS environment variable is set to something other "
        "than 'auto', so PySR was not able to set it. "
        "You may wish to set it to `'auto'` for full use of your CPU."
    )

# TODO: Remove these when juliapkg lets you specify this
# Fill in a default for every variable the user has not already set; values
# that are already present in the environment are left untouched.
for k, default in {
    "PYTHON_JULIACALL_HANDLE_SIGNALS": "yes",
    "JULIA_NUM_THREADS": "auto",
    "JULIA_OPTIMIZE": "3",
}.items():
    os.environ.setdefault(k, default)


from juliacall import Main as jl  # type: ignore
from juliacall import convert as jl_convert  # type: ignore

# Evaluate Julia `using X: X` statements through `jl` (juliacall's `Main`),
# which bind only the module names themselves. `jl.Serialization` is what
# `jl_deserialize_s` below relies on.
jl.seval("using Serialization: Serialization")
jl.seval("using PythonCall: PythonCall")


def _escape_filename(filename):
    """Turn a path into a string with correctly escaped backslashes."""
    str_repr = str(filename)
    str_repr = str_repr.replace("\\", "\\\\")
    return str_repr


def _load_cluster_manager(cluster_manager):
    """Import and return the ``addprocs_<cluster_manager>`` Julia function.

    The name is built by appending ``cluster_manager`` to ``addprocs_``. It is
    first imported from the Julia package ``ClusterManagers`` and then
    evaluated, so the returned value is whatever ``jl.seval`` yields for that
    name.
    """
    addprocs_name = f"addprocs_{cluster_manager}"
    jl.seval(f"using ClusterManagers: {addprocs_name}")
    return jl.seval(addprocs_name)


def jl_array(x):
    """Convert ``x`` to a Julia ``Array``; ``None`` is returned unchanged."""
    return None if x is None else jl_convert(jl.Array, x)


def jl_deserialize_s(s):
    """Deserialize ``s`` with Julia's ``Serialization`` module.

    ``s`` is converted to a Julia array, written into a fresh ``IOBuffer``,
    and read back from the start of that buffer with
    ``Serialization.deserialize``. ``None`` is returned unchanged.
    """
    if s is not None:
        stream = jl.IOBuffer()
        jl.write(stream, jl_array(s))
        # Rewind so that deserialization starts at the first byte written.
        jl.seekstart(stream)
        return jl.Serialization.deserialize(stream)
    return s