File size: 985 Bytes
581eeac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from dataclasses import field
import jax.numpy as jnp
from flax import struct

from jax2d.sim_state import SimState, SimParams, StaticSimParams, RigidBody, Joint, Thruster, CollisionManifold
from kinetix.environment.env_state import EnvState


@struct.dataclass
class PCGState:
    # Primary env state
    env_state: EnvState
    # The PCG mask.  If a value is truthy in this, then it is PCG not static
    env_state_pcg_mask: EnvState
    # In the case that a value is PCG, the env_state value is the min and this state represents the max
    env_state_max: EnvState

    tied_together: jnp.ndarray  # NxN matrix of booleans, where N is the number of shapes

    def __setstate__(self, state):
        if "tied_together" not in state:
            num_shapes = state["env_state"].polygon.active.shape[0] + state["env_state"].circle.active.shape[0]
            state["tied_together"] = jnp.zeros((num_shapes, num_shapes), dtype=bool)
        object.__setattr__(self, "__dict__", state)