Spaces:
Running
Running
File size: 3,096 Bytes
5769ee4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
from dataclasses import dataclass, fields

from mmcv import Config
@dataclass
class CVAEParams:
    """Hyper-parameters of the CVAE model, grouped in a single typed record.

    Attributes:
        dt: Not described in the original documentation; presumably the
            duration of one time step -- TODO confirm.
        state_dim: Dimension of the state at each time step.
        dynamic_state_dim: Not described in the original documentation;
            presumably the dimension of the dynamic part of the state
            -- TODO confirm.
        map_state_dim: Dimension of the map point features at each position.
        max_size_lane: Not described in the original documentation;
            presumably the maximum number of points in a lane -- TODO confirm.
        num_steps: Number of time steps in the past trajectory input.
        num_steps_future: Number of time steps in the future trajectory output.
        latent_dim: Dimension of the latent space.
        hidden_dim: Dimension of the hidden layers.
        num_hidden_layers: Number of layers for each model (encoder, decoder).
        is_mlp_residual: Set to True to add a linear transformation of the
            input to the output of the MLP.
        interaction_type: Whether to use MCG, MAB, or MHB to handle
            interactions.
        num_attention_heads: Number of attention heads to use in MHA blocks.
        mcg_dim_expansion: Dimension expansion factor for the MCG global
            interaction space.
        mcg_num_layers: Number of layers for the MLP MCG blocks.
        num_blocks: Number of interaction blocks to use.
        sequence_encoder_type: Type of sequence encoder: maskedLSTM, LSTM,
            or MLP.
        sequence_decoder_type: Type of sequence decoder: maskedLSTM, LSTM,
            or MLP.
        condition_on_ego_future: Whether to condition the biasing with the
            ego future or only the ego past.
        latent_regularization: Weight of the latent regularization loss.
        risk_assymetry_factor: Not described in the original documentation
            -- TODO document. (The spelling is part of the public field name
            and of the config key, so it is kept as is.)
        num_vq: Not described in the original documentation; presumably the
            number of vector-quantization codes -- TODO confirm.
        latent_distribution: Not described in the original documentation;
            presumably the name of the latent distribution -- TODO confirm.
    """

    dt: float
    state_dim: int
    dynamic_state_dim: int
    map_state_dim: int
    max_size_lane: int
    num_steps: int
    num_steps_future: int
    latent_dim: int
    hidden_dim: int
    num_hidden_layers: int
    is_mlp_residual: bool
    # NOTE(review): annotated as int, but the documentation lists named
    # options (MCG, MAB, MHB) -- confirm whether this should be str.
    interaction_type: int
    num_attention_heads: int
    mcg_dim_expansion: int
    mcg_num_layers: int
    num_blocks: int
    sequence_encoder_type: str
    sequence_decoder_type: str
    condition_on_ego_future: bool
    latent_regularization: float
    risk_assymetry_factor: float
    num_vq: int
    latent_distribution: str

    @classmethod
    def from_config(cls, cfg: "Config") -> "CVAEParams":
        """Build the parameters from a config exposing one attribute per field.

        The field list is taken from the dataclass definition itself, so a
        newly added field is picked up without touching this method, and a
        subclass calling ``from_config`` gets an instance of its own type.

        Args:
            cfg: Object with an attribute named after every init field of
                this dataclass (values are copied as-is, without validation).

        Returns:
            A new instance of ``cls`` populated from ``cfg``.

        Raises:
            AttributeError: If ``cfg`` lacks one of the fields (raised by the
                attribute lookup on ``cfg``).
        """
        # Only init fields can be passed to the generated __init__.
        init_names = (f.name for f in fields(cls) if f.init)
        return cls(**{name: getattr(cfg, name) for name in init_names})
|