# jmercat's picture
# Removed history to avoid any unverified information being released
# 5769ee4
import torch
import torch.nn as nn
from risk_biased.models.nn_blocks import (
SequenceEncoderLSTM,
SequenceEncoderMLP,
SequenceEncoderMaskedLSTM,
)
from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.mlp import MLP
class MapEncoderNN(nn.Module):
    """Encode map objects with an MLP-based sequence encoder.

    Thin wrapper around ``SequenceEncoderMLP`` that is configured from the
    map-related fields of the model parameters.

    Args:
        params: dataclass defining the necessary parameters. The fields read
            here are ``map_state_dim``, ``hidden_dim``, ``num_hidden_layers``,
            ``max_size_lane`` and ``is_mlp_residual``.
    """

    def __init__(self, params: CVAEParams) -> None:
        super().__init__()
        # Forwarded positionally, in exactly this order, to SequenceEncoderMLP.
        encoder_args = (
            params.map_state_dim,
            params.hidden_dim,
            params.num_hidden_layers,
            params.max_size_lane,
            params.is_mlp_residual,
        )
        self._encoder = SequenceEncoderMLP(*encoder_args)

    def forward(self, map, mask_map):
        """Encode each map object's sequence of features into object features.

        Args:
            map: (batch_size, num_objects, object_sequence_length, map_feature_dim)
                tensor of map object sequences
            mask_map: (batch_size, num_objects, object_sequence_length) tensor of
                bool mask

        Returns:
            The output of the wrapped ``SequenceEncoderMLP``, unchanged.
            NOTE(review): presumably one feature vector per object; the exact
            shape is decided by ``SequenceEncoderMLP`` - confirm there.
        """
        # Pure delegation: no pre- or post-processing is applied here.
        return self._encoder(map, mask_map)