File size: 1,112 Bytes
e9b2ee8
58cb6bc
59d9435
 
 
e9b2ee8
 
3ffd1fe
59d9435
 
 
 
 
3ffd1fe
e9b2ee8
3ffd1fe
502e3ec
 
e9b2ee8
3ffd1fe
e9b2ee8
3ffd1fe
e9b2ee8
 
f932bcc
59d9435
 
 
 
 
 
 
 
 
6f2439d
e9b2ee8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""This file installs and loads extensions for SymbolicRegression."""

from typing import Optional

from .julia_import import Pkg, jl


def load_required_packages(
    *,
    turbo: bool = False,
    bumper: bool = False,
    enable_autodiff: bool = False,
    cluster_manager: Optional[str] = None,
):
    """Install (when missing) and load the optional Julia packages a run needs.

    Each enabled option maps to exactly one Julia package, which is handed to
    ``load_package``. Packages are processed in a fixed order:
    LoopVectorization, Bumper, Zygote, ClusterManagers.

    Args:
        turbo: When truthy, load LoopVectorization.
        bumper: When truthy, load Bumper.
        enable_autodiff: When truthy, load Zygote.
        cluster_manager: When not ``None``, load ClusterManagers. The value
            itself is not inspected here, only whether it was given.
    """
    # (enabled?, Julia package name, package UUID); order matters and
    # matches the order in which packages are loaded.
    requested = (
        (turbo, "LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890"),
        (bumper, "Bumper", "8ce10254-0962-460f-a3d8-1f77fea1446e"),
        (enable_autodiff, "Zygote", "e88e6eb3-aa80-5325-afca-941959d7151f"),
        (
            cluster_manager is not None,
            "ClusterManagers",
            "34f1f09b-3a8b-5176-ab39-66d58a4d544e",
        ),
    )
    for enabled, pkg_name, pkg_uuid in requested:
        if enabled:
            load_package(pkg_name, pkg_uuid)


def isinstalled(uuid_s: str):
    """Report whether a package UUID appears in the active Julia environment.

    Args:
        uuid_s: The package UUID as a string; parsed with Julia's ``Base.UUID``.

    Returns:
        The result of Julia's ``haskey`` on the mapping returned by
        ``Pkg.dependencies()``, keyed by the parsed UUID.
    """
    # NOTE(review): Pkg.dependencies() appears to cover the whole manifest,
    # i.e. indirect dependencies as well as direct ones — confirm callers are
    # fine with an indirect-only package counting as "installed".
    dependency_table = Pkg.dependencies()
    return jl.haskey(dependency_table, jl.Base.UUID(uuid_s))


def _is_direct_dependency(uuid_s: str) -> bool:
    """Return True if the package with UUID ``uuid_s`` is a direct project dep."""
    deps = Pkg.dependencies()
    uuid = jl.Base.UUID(uuid_s)
    if not jl.haskey(deps, uuid):
        return False
    # A package that is only in the manifest (pulled in by another package)
    # is not resolvable by a top-level `using`; only direct deps are.
    return bool(jl.getindex(deps, uuid).is_direct_dep)


def load_package(package_name: str, uuid_s: str) -> None:
    """Ensure a Julia package is a direct dependency, then import it into Main.

    If the package is absent, or present only as an indirect dependency, it
    is added with ``Pkg.add`` so that the subsequent ``using`` can resolve it.

    Args:
        package_name: Julia package name. It is interpolated verbatim into
            Julia source passed to ``seval``, so it must come from trusted
            constants, never from user input.
        uuid_s: The package UUID as a string.
    """
    if not _is_direct_dependency(uuid_s):
        Pkg.add(name=package_name, uuid=uuid_s)

    # TODO: Protect against loading the same symbol from two packages,
    #       maybe with a @gensym here.
    jl.seval(f"using {package_name}: {package_name}")