|
from typing import Callable |
|
|
|
import jax |
|
import jax.numpy as jnp |
|
from jaxtyping import Array |
|
|
|
|
|
def uniform_between(a: float, b: float, dtype=jnp.float32) -> Callable:
    """Make an initializer that draws values uniformly between ``a`` and ``b``.

    Args:
        a: Lower bound, passed to ``jax.random.uniform`` as ``minval``.
        b: Upper bound, passed to ``jax.random.uniform`` as ``maxval``.
        dtype: Default dtype of the arrays produced by the returned
            initializer; callers may still override it per call.

    Returns:
        A function ``(key, shape, dtype=<dtype>) -> Array`` that samples an
        array of ``shape`` with ``jax.random.uniform``.
    """
    default_dtype = dtype
    low, high = a, b

    def init(key, shape, dtype=default_dtype) -> Array:
        # Thin wrapper: all randomness comes from the supplied PRNG key.
        return jax.random.uniform(
            key,
            shape,
            dtype=dtype,
            minval=low,
            maxval=high,
        )

    return init
|
|
|
|
|
def linear_up(scale: float) -> Callable:
    """Make an initializer that samples 2-D points uniformly over a disk.

    The returned ``init(key, shape, dtype=jnp.float32)`` requires
    ``shape[-2] == 2``. It returns an array of exactly ``shape`` in which
    index 0 / 1 along axis ``-2`` hold the x / y coordinates of points drawn
    uniformly over a disk of radius ``pi * scale`` centred at the origin.
    Every position along the last axis (and along any leading batch axes)
    is an independent sample.

    Args:
        scale: Disk radius in units of ``pi``.

    Returns:
        An initializer ``init(key, shape, dtype=jnp.float32) -> Array``.

    The initializer raises ``ValueError`` if ``shape`` has fewer than two
    dimensions or ``shape[-2] != 2``.
    """

    def init(key, shape, dtype=jnp.float32) -> Array:
        shape = tuple(shape)
        # Explicit check (not `assert`) so validation survives `python -O`.
        if len(shape) < 2 or shape[-2] != 2:
            raise ValueError(
                f"linear_up expects shape[-2] == 2 (x, y pairs); got shape {shape}"
            )

        radius_key, angle_key = jax.random.split(key, 2)

        # One sample per (leading dims..., column). Axis -2 is a singleton so
        # x and y can be concatenated along it. For 2-D shapes this is
        # (1, shape[-1]), identical to the original draw.
        sample_shape = shape[:-2] + (1, shape[-1])

        # Radius = R * sqrt(U): the sqrt makes the points uniform over the
        # disk's area instead of clustered near the centre.
        norm = jnp.pi * scale * jax.random.uniform(radius_key, sample_shape) ** 0.5
        theta = 2 * jnp.pi * jax.random.uniform(angle_key, sample_shape)

        x = norm * jnp.cos(theta)
        y = norm * jnp.sin(theta)
        # Sampling happens in the default float dtype; cast only at the end.
        return jnp.concatenate([x, y], axis=-2).astype(dtype)

    return init
|
|