Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from risk_biased.models.nn_blocks import ( | |
SequenceEncoderLSTM, | |
SequenceEncoderMLP, | |
SequenceEncoderMaskedLSTM, | |
) | |
from risk_biased.models.cvae_params import CVAEParams | |
from risk_biased.models.mlp import MLP | |
class MapEncoderNN(nn.Module):
    """MLP encoder neural network that encodes map objects.

    Each map object is a sequence of feature vectors. The wrapped
    ``SequenceEncoderMLP`` encodes every such sequence into a single
    object feature.

    Args:
        params: dataclass defining the necessary parameters. The fields read
            here are ``map_state_dim``, ``hidden_dim``, ``num_hidden_layers``,
            ``max_size_lane`` and ``is_mlp_residual``.
    """

    def __init__(self, params: CVAEParams) -> None:
        super().__init__()
        # Arguments are forwarded positionally, so their order must match the
        # SequenceEncoderMLP signature. NOTE(review): presumably
        # (input_dim, hidden_dim, num_layers, sequence_length, residual);
        # confirm against SequenceEncoderMLP.
        self._encoder = SequenceEncoderMLP(
            params.map_state_dim,
            params.hidden_dim,
            params.num_hidden_layers,
            params.max_size_lane,
            params.is_mlp_residual,
        )

    def forward(self, map, mask_map):
        """Encode map object sequences of features into object features.

        Args:
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of map object features. The name shadows the ``map``
                builtin; it is kept unchanged so keyword callers keep working.
            mask_map: (batch_size, num_objects, object_sequence_length) tensor of
                bool mask (presumably True marks valid sequence elements;
                confirm).

        Returns:
            Whatever the wrapped ``SequenceEncoderMLP`` returns for
            ``(map, mask_map)``. Presumably a (batch_size, num_objects,
            hidden_dim) tensor; TODO confirm.
        """
        encoded_map = self._encoder(map, mask_map)
        return encoded_map