File size: 4,243 Bytes
e0f25ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from functools import partial

from jax2d.engine import recalculate_mass_and_inertia, recompute_global_joint_positions, select_shape
from kinetix.environment.env_state import EnvState, StaticEnvParams
from kinetix.pcg.pcg_state import PCGState
import jax
import jax.numpy as jnp


def _process_tied_together_shapes(pcg_state: PCGState, sampled_state: EnvState, static_params: StaticEnvParams):

    # Get the matrix of tied together positions. Since we vmap, we only want one entry active for any (i, j, k). Thus, we mask out some of the duplicate ones.
    tied = jnp.triu(pcg_state.tied_together & jnp.logical_not(jnp.eye(pcg_state.tied_together.shape[0], dtype=bool)))
    has_anything_in_column = tied.any(axis=0)
    tied = (
        tied * jnp.logical_not(has_anything_in_column)[:, None]
    )  # if there is something in a column, it means a previous one with a lower index has already been processed

    should_use_delta_positions = tied.any(axis=0)

    # This is the delta we have moved after sampling
    delta_positions = jnp.concatenate(
        [
            sampled_state.polygon.position - pcg_state.env_state.polygon.position,
            sampled_state.circle.position - pcg_state.env_state.circle.position,
        ]
    )

    def _get_effect_of_shape_i_on_all_others(item_index, item_row_of_what_is_tied):
        delta_pos = delta_positions[item_index]
        return jnp.arange(len(item_row_of_what_is_tied)), delta_pos[None] * item_row_of_what_is_tied[:, None]

    indices, positions = jax.vmap(_get_effect_of_shape_i_on_all_others, (0, 0))(jnp.arange(tied.shape[0]), tied)
    indices = indices.flatten()
    positions = positions.reshape(indices.shape[0], -1)

    default_positions = jnp.concatenate(
        [pcg_state.env_state.polygon.position, pcg_state.env_state.circle.position], axis=0
    )
    sampled_positions = jnp.concatenate([sampled_state.polygon.position, sampled_state.circle.position], axis=0)

    updated_positions = default_positions.at[indices].add(positions)
    # Use the deltas or the sampled positions
    positions = jnp.where(should_use_delta_positions[:, None], updated_positions, sampled_positions)

    sampled_state = sampled_state.replace(
        polygon=sampled_state.polygon.replace(position=positions[: static_params.num_polygons]),
        circle=sampled_state.circle.replace(position=positions[static_params.num_polygons :]),
    )
    return sampled_state


@partial(jax.jit, static_argnums=(3,))
def sample_pcg_state(rng, pcg_state: PCGState, params, static_params):
    def _pcg_fn(rng, main_val, max_val, mask):
        pcg_val = jax.random.uniform(rng, shape=main_val.shape) * (
            max_val.astype(float) - main_val.astype(float)
        ) + main_val.astype(float)
        if jnp.issubdtype(main_val.dtype, jnp.integer) or jnp.issubdtype(main_val.dtype, jnp.bool_):
            pcg_val = jnp.round(pcg_val)
        pcg_val = pcg_val.astype(main_val.dtype)
        new_val = jax.lax.select(mask.astype(bool), pcg_val, main_val)
        return new_val

    def _random_split_like_tree(rng, target):
        tree_def = jax.tree_structure(target)
        rngs = jax.random.split(rng, tree_def.num_leaves)
        return jax.tree_unflatten(tree_def, rngs)

    rng, _rng = jax.random.split(rng)
    rng_tree = _random_split_like_tree(_rng, pcg_state.env_state)

    sampled_state = jax.tree_util.tree_map(
        _pcg_fn, rng_tree, pcg_state.env_state, pcg_state.env_state_max, pcg_state.env_state_pcg_mask
    )

    sampled_state = _process_tied_together_shapes(pcg_state, sampled_state, static_params)

    sampled_state = recompute_global_joint_positions(sampled_state, static_params)

    env_state = recalculate_mass_and_inertia(
        sampled_state, static_params, sampled_state.polygon_densities, sampled_state.circle_densities
    )

    return env_state


def env_state_to_pcg_state(env_state: EnvState):
    N = env_state.polygon.active.shape[0] + env_state.circle.active.shape[0]
    pcg_state = PCGState(
        env_state=env_state,
        env_state_max=env_state,
        env_state_pcg_mask=jax.tree_util.tree_map(lambda x: jnp.zeros_like(x, dtype=bool), env_state),
        tied_together=jnp.zeros((N, N), dtype=bool),
    )

    return pcg_state