# File size: 808 Bytes
# Commit: 1eb87a5
from typing import Callable
import jax
import jax.numpy as jnp
from jaxtyping import Array
def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
    """Create an initializer that draws values uniformly between ``a`` and ``b``.

    Args:
        a: Lower bound passed to ``jax.random.uniform`` as ``minval``.
        b: Upper bound passed to ``jax.random.uniform`` as ``maxval``.
        dtype: Default dtype of the returned initializer's output.

    Returns:
        A function ``init(key, shape, dtype=dtype) -> Array`` that samples
        ``jax.random.uniform`` over ``shape`` with the bounds captured here.
    """
    default_dtype = dtype
    low, high = a, b

    def init(key, shape, dtype=default_dtype) -> Array:
        samples = jax.random.uniform(
            key,
            shape,
            dtype=dtype,
            minval=low,
            maxval=high,
        )
        return samples

    return init
def linear_up(scale: float) -> Callable:
    """Build an initializer that samples 2-D points uniformly inside a disk.

    The returned ``init(key, shape, dtype)`` requires ``shape[-2] == 2``: index
    0 along that axis holds x-coordinates and index 1 holds y-coordinates. Each
    of the ``shape[-1]`` columns is a point drawn uniformly from the disk of
    radius ``pi * scale`` centred at the origin. Any leading dimensions of
    ``shape`` are treated as batch dimensions and sampled independently.

    Args:
        scale: Disk radius in units of pi.

    Returns:
        An initializer ``init(key, shape, dtype=jnp.float32) -> Array``.
    """

    def init(key, shape, dtype=jnp.float32) -> Array:
        """Sample points of the given shape.

        Raises:
            ValueError: if ``shape`` has fewer than two dims or ``shape[-2] != 2``.
        """
        shape = tuple(shape)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(shape) < 2 or shape[-2] != 2:
            raise ValueError(
                f"linear_up expects shape[-2] == 2 (x/y rows), got shape {shape}"
            )
        # One row per coordinate; leading (batch) dims are preserved.
        row_shape = shape[:-2] + (1, shape[-1])
        norm_key, angle_key = jax.random.split(key, 2)
        # Radius ~ sqrt(U) gives a density proportional to r, i.e. uniform
        # over the disk's area rather than clustered at the centre.
        norm = jnp.pi * scale * jax.random.uniform(norm_key, shape=row_shape) ** 0.5
        theta = 2 * jnp.pi * jax.random.uniform(angle_key, shape=row_shape)
        x = norm * jnp.cos(theta)
        y = norm * jnp.sin(theta)
        # Sampling happens in JAX's default float dtype; cast only at the end.
        return jnp.concatenate([x, y], axis=-2).astype(dtype)

    return init
|