MilesCranmer commited on
Commit
88d93a1
1 Parent(s): 530ae99

feat: add helper for specifying dtype of jl Array

Browse files
Files changed (1) hide show
  1. pysr/julia_helpers.py +9 -2
pysr/julia_helpers.py CHANGED
@@ -32,10 +32,17 @@ def _load_cluster_manager(cluster_manager: str):
32
  return jl.seval(f"addprocs_{cluster_manager}")
33
 
34
 
35
- def jl_array(x):
36
  if x is None:
37
  return None
38
- return jl_convert(jl.Array, x)
 
 
 
 
 
 
 
39
 
40
 
41
  def jl_serialize(obj: Any) -> NDArray[np.uint8]:
 
32
  return jl.seval(f"addprocs_{cluster_manager}")
33
 
34
 
35
+ def jl_array(x, dtype=None):
36
  if x is None:
37
  return None
38
+ elif dtype is None:
39
+ return jl_convert(jl.Array, x)
40
+ else:
41
+ return jl_convert(jl.Array[dtype], x)
42
+
43
+
44
+ def jl_is_function(f):
45
+ return jl.seval("op -> op isa Function")(f)
46
 
47
 
48
  def jl_serialize(obj: Any) -> NDArray[np.uint8]: