File size: 1,594 Bytes
ce60798
8cf4bf3
810bea9
fd4c500
70b842a
7ea0b2f
810bea9
e530637
d16abb4
caf58a4
 
fd4c500
 
d72c643
44ff874
4037c2d
70b842a
 
 
c5c6896
ffc5f5c
6baa534
 
80fecb9
6baa534
 
 
4bc0a76
 
810bea9
68ea1be
 
a4bb529
 
88d93a1
a4bb529
 
88d93a1
 
 
 
 
 
9f3b918
 
d72c643
 
810bea9
70b842a
 
 
 
 
810bea9
d72c643
 
 
 
 
70b842a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
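"""Helpers for working with Julia from Python.

Covers converting values to Julia arrays, (de)serializing Julia objects, and
Julia environment setup (``init_julia`` / ``install`` are imported here).
"""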

from typing import Any, Callable, Union, cast

import numpy as np
from juliacall import convert as jl_convert  # type: ignore
from numpy.typing import NDArray

from .deprecated import init_julia, install
from .julia_import import jl

# Re-bind juliacall's untyped ``convert`` with a concrete callable type so
# static type checkers accept calls to it; ``cast`` is a no-op at runtime.
jl_convert = cast(Callable[[Any, Any], Any], jl_convert)

# ``using Mod: Mod`` brings only the module binding itself into Julia's Main
# (not its exported names), so it can be reached as an attribute of ``jl``.
jl.seval("using Serialization: Serialization")
jl.seval("using PythonCall: PythonCall")

# Python-side handles to the Julia modules loaded above. ``Serialization`` is
# used by jl_serialize / jl_deserialize in this file.
Serialization = jl.Serialization
PythonCall = jl.PythonCall

# Import SymbolicRegression's operator functions into Main.
# NOTE(review): presumably so these operator names resolve when Julia code is
# evaluated via ``jl.seval`` elsewhere — confirm against callers.
jl.seval("using SymbolicRegression: plus, sub, mult, div, pow")


def _escape_filename(filename):
    """Return ``filename`` as a string in which every backslash is doubled.

    Anything with a ``str`` representation (e.g. a ``pathlib`` path) is
    accepted. Doubling backslashes lets the text survive being embedded in a
    string literal that treats backslash as an escape character.
    """
    return str(filename).replace("\\", "\\\\")


def _load_cluster_manager(cluster_manager: str):
    """Import ``addprocs_<cluster_manager>`` from ClusterManagers and return it.

    ``cluster_manager`` is spliced verbatim into Julia source evaluated with
    ``jl.seval``, so it should be a trusted name.
    """
    fn_name = f"addprocs_{cluster_manager}"
    jl.seval(f"using ClusterManagers: {fn_name}")
    return jl.seval(fn_name)


def jl_array(x, dtype=None):
    """Convert ``x`` to a Julia ``Array`` using juliacall's ``convert``.

    When ``dtype`` is given the target type is ``Array{dtype}``; otherwise it
    is the unparameterized ``Array``, leaving the element type to the
    conversion. ``None`` is passed straight through and returned as ``None``.
    """
    if x is None:
        return None
    target = jl.Array if dtype is None else jl.Array[dtype]
    return jl_convert(target, x)


def jl_is_function(f) -> bool:
    """Return whether ``f``, as seen from Julia, is an instance of ``Function``.

    Calls Julia's builtin ``isa`` directly. The previous implementation
    ``jl.seval``-ed a new anonymous function on every call, which re-parsed
    Julia source, created a fresh closure type in Main and triggered new
    compilation each time, for the same result.
    """
    # juliacall converts ``f`` to a Julia value for the call, exactly as it
    # did when passing ``f`` to the seval-ed lambda, and converts the Julia
    # ``Bool`` result back to a Python ``bool``.
    return cast(bool, jl.isa(f, jl.Function))


def jl_serialize(obj: Any) -> NDArray[np.uint8]:
    """Serialize ``obj`` with Julia's ``Serialization`` stdlib.

    The object is written into an in-memory Julia ``IOBuffer``. The buffer's
    contents are drained with ``take!`` (exposed by juliacall as ``take_b``)
    and returned as a NumPy array.
    """
    io = jl.IOBuffer()
    Serialization.serialize(io, obj)
    raw_bytes = jl.take_b(io)
    return np.array(raw_bytes)


def jl_deserialize(s: Union[NDArray[np.uint8], None]):
    """Rebuild a Julia object from serialized bytes (e.g. ``jl_serialize`` output).

    The bytes are copied into a Julia ``IOBuffer`` and read back with
    ``Serialization.deserialize``. A ``None`` input is returned unchanged.
    """
    if s is not None:
        io = jl.IOBuffer()
        jl.write(io, jl_array(s))
        # Rewind so deserialize starts reading at the first byte just written.
        jl.seekstart(io)
        return Serialization.deserialize(io)
    return s