Spaces:
Runtime error
Runtime error
import math | |
from functools import partial | |
import jax | |
import jax.numpy as jnp | |
from jax2d.engine import PhysicsEngine, calculate_collision_matrix, recalculate_mass_and_inertia, select_shape | |
from jax2d.sim_state import RigidBody, Thruster | |
from kinetix.environment.env_state import EnvParams, EnvState, StaticEnvParams | |
def sample_dimensions(rng, static_env_params: StaticEnvParams, is_rect: bool, ued_params, max_shape_size=None):
    """Sample the size of a new shape.

    Args:
        rng: JAX PRNG key.
        static_env_params: Supplies the default ``max_shape_size``.
        is_rect: When true the rectangle half-dimensions are randomised and the
            radius is pinned to its minimum; when false the radius is randomised
            and the half-dimensions start from their minimum.
        ued_params: Read for ``circle_max_size_coeff``, ``large_rect_dim_chance``
            and ``large_rect_dim_scale``.
        max_shape_size: Optional override for ``static_env_params.max_shape_size``.

    Returns:
        A ``(vertices, half_dimensions, radius)`` tuple, where ``vertices`` holds
        the four corners ``(+,+), (+,-), (-,-), (-,+)`` of the rectangle described
        by ``half_dimensions``.
    """
    if max_shape_size is None:
        max_shape_size = static_env_params.max_shape_size

    # Lower bounds so that sampled shapes are never overly small.
    min_rect_size = 0.05
    min_circle_size = 0.1
    cap_rect = max_shape_size / 2.0 / jnp.sqrt(2.0)
    cap_circ = max_shape_size / 2.0 * ued_params.circle_max_size_coeff

    # Rectangle half-extents: a uniform fraction of [min_rect_size, cap_rect).
    rng, rect_key = jax.random.split(rng)
    rect_fraction = jax.lax.select(
        is_rect,
        jax.random.uniform(rect_key, shape=(2,)),
        jnp.zeros(2, dtype=jnp.float32),
    )
    half_dimensions = rect_fraction * (cap_rect - min_rect_size) + min_rect_size

    # With probability `large_rect_dim_chance`, stretch one randomly chosen axis.
    rng, axis_key, chance_key = jax.random.split(rng, 3)
    stretched_axis = jax.random.randint(axis_key, shape=(), minval=0, maxval=2)
    stretch_factor = jax.lax.select(
        jax.random.uniform(chance_key) < ued_params.large_rect_dim_chance,
        ued_params.large_rect_dim_scale,
        1.0,
    )
    dim_scale = jnp.ones(2).at[stretched_axis].set(stretch_factor)
    half_dimensions = half_dimensions * dim_scale

    # One sign pair per corner; broadcasting yields the (4, 2) vertex array.
    corner_signs = jnp.array([[1, 1], [1, -1], [-1, -1], [-1, 1]])
    vertices = half_dimensions * corner_signs

    # Circle radius: a uniform fraction of [min_circle_size, cap_circ).
    rng, radius_key = jax.random.split(rng)
    radius_fraction = jax.lax.select(
        is_rect,
        jnp.zeros((), dtype=jnp.float32),
        jax.random.uniform(radius_key, shape=()),
    )
    radius = radius_fraction * (cap_circ - min_circle_size) + min_circle_size

    return vertices, half_dimensions, radius
def count_roles(state: EnvState, static_env_params: StaticEnvParams, role: int, include_static_polys=True) -> int:
    """Count the active polygons and circles whose shape role equals ``role``.

    When ``include_static_polys`` is false, the first
    ``static_env_params.num_static_fixated_polys`` polygon slots are treated as
    inactive and therefore never counted.
    """
    polygon_mask = state.polygon.active
    if not include_static_polys:
        n_static = static_env_params.num_static_fixated_polys
        polygon_mask = polygon_mask.at[:n_static].set(False)

    matching_polygons = ((state.polygon_shape_roles == role) * polygon_mask).sum()
    matching_circles = ((state.circle_shape_roles == role) * state.circle.active).sum()
    return matching_polygons + matching_circles
def random_position_on_triangle(rng, vertices):
    """Sample a random point inside the triangle given by the first three vertices.

    Args:
        rng: JAX PRNG key.
        vertices: Array of 2D vertices; only ``vertices[:3]`` is used.

    Returns:
        The sampled 2D position.
    """
    a, b, c = vertices[0], vertices[1], vertices[2]
    _, key_toward_edge, key_along_edge = jax.random.split(rng, 3)
    toward_edge = jax.random.uniform(key_toward_edge)
    along_edge = jax.random.uniform(key_along_edge)
    # https://www.reddit.com/r/godot/comments/mqp29g/how_do_i_get_a_random_position_inside_a_collision/
    # Pick a point on the edge b -> c, then move from a towards it by sqrt(toward_edge).
    offset_to_edge_point = (b - a) + along_edge * (c - b)
    return a + jnp.sqrt(toward_edge) * offset_to_edge_point
def random_position_on_rectangle(rng, vertices):
    """Sample a random point in the axis-aligned bounding box of the first four vertices.

    Args:
        rng: JAX PRNG key.
        vertices: Array of 2D vertices; only ``vertices[:4]`` is used.

    Returns:
        A length-2 array ``[x, y]``.
    """
    corners = vertices[:4]
    _, x_key, y_key = jax.random.split(rng, 3)
    # Each axis gets its own key, so x and y are drawn independently.
    fractions = jnp.stack([jax.random.uniform(x_key), jax.random.uniform(y_key)])
    lower = jnp.min(corners, axis=0)
    upper = jnp.max(corners, axis=0)
    return lower + fractions * (upper - lower)
def random_position_on_polygon(rng, vertices, n_vertices, static_env_params: StaticEnvParams):
    """Sample a random point on a polygon with at most four vertices.

    Both the triangle and the rectangle samples are computed from the same key;
    ``n_vertices`` then selects which one is returned (triangle when
    ``n_vertices <= 3``, rectangle otherwise).
    """
    assert static_env_params.max_polygon_vertices <= 4, "Only supports up to 4 vertices"
    triangle_point = random_position_on_triangle(rng, vertices)
    rectangle_point = random_position_on_rectangle(rng, vertices)
    use_triangle = n_vertices <= 3
    return jax.lax.select(use_triangle, triangle_point, rectangle_point)
def random_position_on_circle(rng, radius, on_centre_chance):
    """Sample a local position on a circle, optionally snapping to its centre.

    Args:
        rng: JAX PRNG key.
        radius: Circle radius; the sampled distance from the centre is a uniform
            fraction of it.
        on_centre_chance: Probability of returning the centre ``[0, 0]`` instead
            of the sampled polar offset.

    Returns:
        A length-2 array holding the local position.
    """
    centre_key, angle_key, distance_key = jax.random.split(rng, 3)
    snap_to_centre = jax.random.uniform(centre_key) < on_centre_chance

    # Polar sample: uniform angle in [0, 2*pi) and uniform distance in [0, radius).
    theta = jax.random.uniform(angle_key, shape=()) * 2 * math.pi
    distance = jax.random.uniform(distance_key, shape=()) * radius
    polar_offset = jnp.array([distance * jnp.cos(theta), distance * jnp.sin(theta)])

    centre = jnp.array([0.0, 0.0])
    return jax.lax.select(snap_to_centre, centre, polar_offset)
def get_role(rng, state: EnvState, static_env_params: StaticEnvParams, initial_p=None) -> int:
    """Randomly pick the role (0-3) for the next shape added to ``state``.

    Roles 1, 2 and 3 are counted as ball, goal and lava respectively. The
    sampling weights are ``initial_p`` masked so that:

    * role 1 (ball) is only possible while no ball exists,
    * role 2 (goal) is only possible while no goal exists,
    * role 0 is only possible once both a ball and a goal exist,
    * role 3 (lava) is only possible once both exist and there is no lava yet,
      at one third of its initial weight.

    This means a ball and a goal are always placed first, and there can never be
    more than one ball or more than one goal.

    Args:
        rng: JAX PRNG key.
        state: Environment state whose current roles are counted.
        static_env_params: Forwarded to ``count_roles``.
        initial_p: Optional length-4 base weights; defaults to all ones.
    """
    if initial_p is None:
        initial_p = jnp.array([1.0, 1.0, 1.0, 1.0])

    needs_ball = count_roles(state, static_env_params, 1) == 0
    needs_goal = count_roles(state, static_env_params, 2) == 0
    needs_lava = count_roles(state, static_env_params, 3) == 0

    # Other roles only become available once the ball and goal are both present.
    ball_and_goal_present = jnp.logical_not(needs_ball) & jnp.logical_not(needs_goal)
    lava_weight = ball_and_goal_present * needs_lava / 3
    role_mask = jnp.array([ball_and_goal_present, needs_ball, needs_goal, lava_weight])

    weights = initial_p * role_mask
    candidate_roles = jnp.arange(4)
    return jax.random.choice(rng, candidate_roles, p=weights)
def is_space_for_shape(state: EnvState):
    """Return whether at least one polygon or circle slot is inactive."""
    slot_in_use = jnp.concatenate([state.polygon.active, state.circle.active])
    n_free_slots = jnp.logical_not(slot_in_use).sum()
    return n_free_slots > 0
def is_space_for_joint(state: EnvState):
    """Return whether at least one joint slot is inactive."""
    n_free_joints = jnp.logical_not(state.joint.active).sum()
    return n_free_joints > 0
def are_there_shapes_present(state: EnvState, static_env_params: StaticEnvParams):
    """Return whether any shape is active beyond the leading static slots.

    The polygon and circle activity masks are concatenated (polygons first) and
    the first ``static_env_params.num_static_fixated_polys`` entries are ignored.
    """
    all_active = jnp.concatenate([state.polygon.active, state.circle.active])
    n_static = static_env_params.num_static_fixated_polys
    non_static_active = all_active.at[:n_static].set(False)
    return non_static_active.sum() > 0
def add_rigidbody_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    vertices: jnp.ndarray,
    n_vertices: int,
    radius: float,
    shape_role: int,
    density: float = 1,
    is_circle: bool = False,
):
    """Insert a new rigid body into a free polygon or circle slot of ``state``.

    The body is created at rest (zero velocity, rotation and angular velocity).
    After the insertion the collision matrix and the masses/inertias are
    recomputed. If no slot of the requested kind is free, ``state`` is returned
    unchanged.

    Args:
        state: Environment state to extend.
        env_params: Not used by this function.
        static_env_params: Forwarded to the collision-matrix and mass/inertia helpers.
        position: Position of the new body.
        vertices: Polygon vertices of the new body.
        n_vertices: Number of vertices of the new body.
        radius: Radius of the new body.
        shape_role: Role stored in the matching ``*_shape_roles`` array.
        density: Density stored in the matching ``*_densities`` array.
        is_circle: Insert into the circle slots if true, otherwise the polygon slots.

    Returns:
        The updated (or, when full, the original) state.
    """
    body = RigidBody(
        position=position,
        velocity=jnp.array([0.0, 0.0]),
        inverse_mass=1.0,
        inverse_inertia=1.0,
        rotation=0.0,
        angular_velocity=0.0,
        radius=radius,
        active=True,
        friction=1.0,
        vertices=vertices,
        n_vertices=n_vertices,
        collision_mode=1,
        restitution=0.0,
    )

    slot_in_use = state.circle.active if is_circle else state.polygon.active
    # argmin returns the first minimum, i.e. the first inactive slot
    # (assuming `active` is a boolean mask).
    slot = jnp.argmin(slot_in_use)

    def _write_slot(stacked, single):
        # Store one field of the new body at the chosen slot of the stacked field.
        return stacked.at[slot].set(single)

    def _insert(current):
        if is_circle:
            current = current.replace(
                circle=jax.tree.map(_write_slot, current.circle, body),
                circle_densities=current.circle_densities.at[slot].set(density),
                circle_shape_roles=current.circle_shape_roles.at[slot].set(shape_role),
            )
        else:
            current = current.replace(
                polygon=jax.tree.map(_write_slot, current.polygon, body),
                polygon_densities=current.polygon_densities.at[slot].set(density),
                polygon_shape_roles=current.polygon_shape_roles.at[slot].set(shape_role),
            )
        current = current.replace(
            collision_matrix=calculate_collision_matrix(static_env_params, current.joint),
        )
        return recalculate_mass_and_inertia(
            current, static_env_params, current.polygon_densities, current.circle_densities
        )

    def _leave_unchanged(current):
        return current

    has_free_slot = jnp.logical_not(slot_in_use).sum() > 0
    return jax.lax.cond(has_free_slot, _insert, _leave_unchanged, state)
def rectangle_vertices(half_dim):
    """Return the four corners of an origin-centred rectangle.

    Args:
        half_dim: Half-extents ``[half_width, half_height]``.

    Returns:
        A ``(4, 2)`` array of corners in the order
        ``(+,+), (+,-), (-,-), (-,+)``.
    """
    # One sign pair per corner; broadcasting scales each by the half-extents.
    corner_signs = jnp.array([[1, 1], [1, -1], [-1, -1], [-1, 1]])
    return half_dim * corner_signs
# More Manual Control | |
def add_rectangle_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    width: float,
    height: float,
    shape_role: int,
    density: float = 1,
):
    """Add an origin-centred ``width`` x ``height`` rectangle to ``state``.

    Thin wrapper around ``add_rigidbody_to_state`` that builds the four corner
    vertices from the full width and height and inserts the body as a polygon
    with a radius of 0.
    """
    half_extents = jnp.array([width, height]) / 2
    return add_rigidbody_to_state(
        state,
        env_params,
        static_env_params,
        position,
        vertices=rectangle_vertices(half_extents),
        n_vertices=4,
        radius=0.0,
        shape_role=shape_role,
        density=density,
        is_circle=False,
    )
def add_circle_to_state(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    position: jnp.ndarray,
    radius: float,
    shape_role: int,
    density: float = 1,
):
    """Add a circle of the given ``radius`` to ``state``.

    Thin wrapper around ``add_rigidbody_to_state``; circles carry a zero
    placeholder for their vertices and a vertex count of 0.
    """
    placeholder_vertices = jnp.array([0.0, 0.0])
    return add_rigidbody_to_state(
        state,
        env_params,
        static_env_params,
        position,
        vertices=placeholder_vertices,
        n_vertices=0,
        radius=radius,
        shape_role=shape_role,
        density=density,
        is_circle=True,
    )
def add_thruster_to_object(
    state: EnvState,
    env_params: EnvParams,
    static_env_params: StaticEnvParams,
    shape_index: int,
    rotation: float,
    colour: int,
    thruster_power_multiplier: float,
):
    """Attach a thruster to the shape at ``shape_index``.

    The thruster is written into the first inactive thruster slot and placed at
    the shape's centre. Nothing is changed when the target shape is inactive or
    when every thruster slot is already in use.

    Args:
        state: Environment state to extend.
        env_params: Not used by this function.
        static_env_params: Forwarded to ``select_shape``.
        shape_index: Index of the shape that receives the thruster.
        rotation: Rotation stored on the thruster.
        colour: Value written to ``thruster_bindings`` for the new thruster.
        thruster_power_multiplier: Scale applied to the mass-based thruster power.

    Returns:
        The updated (or unchanged) state.
    """
    target_shape = select_shape(state, shape_index, static_env_params)

    def _leave_unchanged(current: EnvState):
        return current

    def _attach(current: EnvState):
        # argmin returns the first minimum, i.e. the first inactive thruster slot
        # (assuming `active` is a boolean mask).
        thruster_idx = jnp.argmin(current.thruster.active)
        shape = select_shape(current, shape_index, static_env_params)
        # Power is proportional to the body's mass (1 / inverse_mass); an
        # inverse mass of 0 is replaced by 1 to avoid dividing by zero.
        safe_inverse_mass = jax.lax.select(shape.inverse_mass == 0, 1.0, shape.inverse_mass)
        thruster = Thruster(
            object_index=shape_index,
            active=True,
            relative_position=jnp.array([0.0, 0.0]),  # a bit of a hack but reasonable.
            rotation=rotation,
            power=1.0 / safe_inverse_mass * thruster_power_multiplier,
            global_position=shape.position,
        )
        # `jax.tree.map` is the supported spelling; the top-level `jax.tree_map`
        # alias is deprecated.
        return current.replace(
            thruster=jax.tree.map(lambda stacked, single: stacked.at[thruster_idx].set(single), current.thruster, thruster),
            thruster_bindings=current.thruster_bindings.at[thruster_idx].set(colour),
        )

    has_free_thruster = jnp.logical_not(state.thruster.active).sum() > 0
    return jax.lax.cond(
        (target_shape.active) & has_free_thruster,
        _attach,
        _leave_unchanged,
        state,
    )
def make_velocities_zero(state: EnvState):
    """Return ``state`` with every polygon and circle velocity multiplied by zero.

    Both the linear ``velocity`` and the ``angular_velocity`` fields are scaled;
    all other fields are left as they are.
    """
    polygons_at_rest = state.polygon.replace(
        angular_velocity=state.polygon.angular_velocity * 0,
        velocity=state.polygon.velocity * 0,
    )
    circles_at_rest = state.circle.replace(
        angular_velocity=state.circle.angular_velocity * 0,
        velocity=state.circle.velocity * 0,
    )
    return state.replace(polygon=polygons_at_rest, circle=circles_at_rest)
def make_do_dummy_step(
    params: EnvParams, static_sim_params: StaticEnvParams, zero_collisions=True, zero_velocities=True
):
    """Build a function that advances a state by one physics step with no actions and no gravity.

    Args:
        params: Environment parameters passed to every engine step.
        static_sim_params: Static parameters used to construct the engine and to
            size the all-zero action vector (``num_joints + num_thrusters``).
        zero_collisions: If true, the collision matrix is cleared for the
            duration of the step.
        zero_velocities: If true, all polygon and circle velocities are zeroed
            after the step.

    Returns:
        A ``do_dummy_step(state) -> state`` callable. The state's original gravity
        and collision matrix are restored after the step.
    """
    engine = PhysicsEngine(static_sim_params)
    n_actions = static_sim_params.num_joints + static_sim_params.num_thrusters

    def _step_fn(state):
        state, _ = engine.step(state, params, jnp.zeros((n_actions,)))
        return state

    def do_dummy_step(state: EnvState) -> EnvState:
        # Remember the real values so they can be restored after the step.
        original_collision_matrix = state.collision_matrix
        original_gravity = state.gravity

        # `& (not zero_collisions)` clears the matrix when collisions are
        # disabled and leaves it untouched otherwise.
        state = state.replace(
            collision_matrix=state.collision_matrix & (not zero_collisions), gravity=state.gravity * 0
        )
        state = _step_fn(state)
        state = state.replace(gravity=original_gravity, collision_matrix=original_collision_matrix)

        if zero_velocities:
            state = make_velocities_zero(state)
        return state

    return do_dummy_step