from typing import Callable

import jax
import jax.numpy as jnp

try:
    from jaxtyping import Array
except ImportError:
    # ``jaxtyping.Array`` re-exports ``jax.Array``; here it is only used in
    # return annotations, so fall back to the JAX type when jaxtyping is absent.
    Array = jax.Array


def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
    """Return an initializer that draws i.i.d. samples uniformly from ``[a, b)``.

    Args:
        a: Lower bound passed to ``jax.random.uniform`` as ``minval``.
        b: Upper bound passed to ``jax.random.uniform`` as ``maxval``.
        dtype: Default dtype of the samples; each call may override it.

    Returns:
        ``init(key, shape, dtype=dtype) -> Array`` producing an array of
        ``shape``.
    """
    def init(key, shape, dtype=dtype) -> Array:
        return jax.random.uniform(key, shape, dtype=dtype, minval=a, maxval=b)

    return init


def linear_up(scale: float) -> Callable:
    """Return an initializer whose axis ``-2`` holds ``(x, y)`` coordinate pairs.

    For every index along the last axis (and every leading/batch index), a
    pair ``(x, y)`` is drawn uniformly from the disk of radius ``pi * scale``.
    The radius is ``pi * scale * sqrt(u)`` (the square root makes the density
    uniform over area) and the angle is uniform in ``[0, 2*pi)``.

    Args:
        scale: Disk radius in units of pi.

    Returns:
        ``init(key, shape, dtype=jnp.float32) -> Array``. ``shape`` must be of
        the form ``(..., 2, n)``; row 0 of axis ``-2`` is ``x`` and row 1 is
        ``y``. ``init`` raises ``ValueError`` for any other shape.
    """
    def init(key, shape, dtype=jnp.float32) -> Array:
        shape = tuple(shape)
        # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
        if len(shape) < 2 or shape[-2] != 2:
            raise ValueError(
                f"linear_up expects shape (..., 2, n); got {shape}")
        # Keep leading dims; the size-1 axis is where x and y are stacked.
        # For a 2-D shape this is (1, n), exactly as before.
        sample_shape = shape[:-2] + (1, shape[-1])
        radius_key, angle_key = jax.random.split(key, 2)
        norm = jnp.pi * scale * (
            jax.random.uniform(radius_key, shape=sample_shape) ** 0.5)
        theta = 2 * jnp.pi * jax.random.uniform(angle_key, shape=sample_shape)
        x = norm * jnp.cos(theta)
        y = norm * jnp.sin(theta)
        return jnp.concatenate([x, y], axis=-2).astype(dtype)

    return init