File size: 3,096 Bytes
5769ee4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from dataclasses import dataclass, fields

from mmcv import Config


@dataclass
class CVAEParams:
    """Hyperparameters of the CVAE model, usually built from an mmcv ``Config``.

    dt: Presumably the duration of one time step (units not visible here — confirm).
    state_dim: Dimension of the state at each time step.
    dynamic_state_dim: Presumably the dimension of the dynamic part of the state — confirm.
    map_state_dim: Dimension of the map point features at each position.
    max_size_lane: Presumably the maximum number of points per lane — confirm.
    num_steps: Number of time steps in the past trajectory input.
    num_steps_future: Number of time steps in the future trajectory output.
    latent_dim: Dimension of the latent space.
    hidden_dim: Dimension of the hidden layers.
    num_hidden_layers: Number of layers for each model (encoder, decoder).
    is_mlp_residual: Set to True to add a linear transformation of the input to the
        output of the MLP.
    interaction_type: Which interaction module to use: MCG, MAB, or MHB.
    num_attention_heads: Number of attention heads to use in MHA blocks.
    mcg_dim_expansion: Dimension expansion factor for the MCG global interaction space.
    mcg_num_layers: Number of layers for the MLP MCG blocks.
    num_blocks: Number of interaction blocks to use.
    sequence_encoder_type: Type of sequence encoder: maskedLSTM, LSTM, or MLP.
    sequence_decoder_type: Type of sequence decoder: maskedLSTM, LSTM, or MLP.
    condition_on_ego_future: Whether to condition the biasing on the ego future or only
        on the ego past.
    latent_regularization: Weight of the latent regularization loss.
    risk_assymetry_factor: Risk asymmetry factor; its exact role is not visible here —
        confirm against the model code.
    num_vq: Presumably the number of vector-quantization codebook entries — confirm.
    latent_distribution: Name of the latent distribution to use; accepted values are not
        visible here.
    """

    dt: float
    state_dim: int
    dynamic_state_dim: int
    map_state_dim: int
    max_size_lane: int
    num_steps: int
    num_steps_future: int
    latent_dim: int
    hidden_dim: int
    num_hidden_layers: int
    is_mlp_residual: bool
    # NOTE(review): the docstring describes a named choice (MCG/MAB/MHB); confirm whether
    # this annotation should be ``str`` rather than ``int``.
    interaction_type: int
    num_attention_heads: int
    mcg_dim_expansion: int
    mcg_num_layers: int
    num_blocks: int
    sequence_encoder_type: str
    sequence_decoder_type: str
    condition_on_ego_future: bool
    latent_regularization: float
    risk_assymetry_factor: float
    num_vq: int
    latent_distribution: str

    @classmethod
    def from_config(cls, cfg: "Config") -> "CVAEParams":
        """Build the parameters from a config exposing one attribute per field.

        Every init field of the dataclass is read from ``cfg`` by name, so a newly
        declared field (here or in a subclass) is loaded without touching this method.
        Extra entries in ``cfg`` are ignored.

        Args:
            cfg: mmcv ``Config`` (or any object with attributes named like the fields).

        Returns:
            A new instance of ``cls`` populated from ``cfg``.

        Raises:
            AttributeError: If ``cfg`` lacks any required field; all missing names are
                reported together instead of only the first one.
        """
        names = [f.name for f in fields(cls) if f.init]
        missing = [name for name in names if not hasattr(cfg, name)]
        if missing:
            raise AttributeError(
                f"Config is missing required {cls.__name__} fields: {', '.join(missing)}"
            )
        return cls(**{name: getattr(cfg, name) for name in names})