from typing import Optional
import torch
import torch.nn as nn
from risk_biased.models.cvae_params import CVAEParams
from risk_biased.models.nn_blocks import (
MCG,
MAB,
MHB,
SequenceEncoderLSTM,
SequenceEncoderMLP,
SequenceEncoderMaskedLSTM,
)
from risk_biased.models.latent_distributions import AbstractLatentDistribution
class BaseEncoderNN(nn.Module):
"""Base encoder neural network that defines the common functionality of encoders.
It should not be used directly but rather extended to define specific encoders.
Args:
params: dataclass defining the necessary parameters
num_steps: length of the input sequence
"""
def __init__(
self,
params: CVAEParams,
latent_dim: int,
num_steps: int,
) -> None:
super().__init__()
self.is_mlp_residual = params.is_mlp_residual
self.num_hidden_layers = params.num_hidden_layers
self.num_steps = params.num_steps
self.num_steps_future = params.num_steps_future
self.sequence_encoder_type = params.sequence_encoder_type
self.state_dim = params.state_dim
self.latent_dim = latent_dim
self.hidden_dim = params.hidden_dim
if params.sequence_encoder_type == "MLP":
self._agent_encoder = SequenceEncoderMLP(
params.state_dim,
params.hidden_dim,
params.num_hidden_layers,
num_steps,
params.is_mlp_residual,
)
elif params.sequence_encoder_type == "LSTM":
self._agent_encoder = SequenceEncoderLSTM(
params.state_dim, params.hidden_dim
)
elif params.sequence_encoder_type == "maskedLSTM":
self._agent_encoder = SequenceEncoderMaskedLSTM(
params.state_dim, params.hidden_dim
)
if params.interaction_type == "Attention" or params.interaction_type == "MAB":
self._interaction = MAB(
params.hidden_dim, params.num_attention_heads, params.num_blocks
)
elif (
params.interaction_type == "ContextGating"
or params.interaction_type == "MCG"
):
self._interaction = MCG(
params.hidden_dim,
params.mcg_dim_expansion,
params.mcg_num_layers,
params.num_blocks,
params.is_mlp_residual,
)
elif params.interaction_type == "Hybrid" or params.interaction_type == "MHB":
self._interaction = MHB(
params.hidden_dim,
params.num_attention_heads,
params.mcg_dim_expansion,
params.mcg_num_layers,
params.num_blocks,
params.is_mlp_residual,
)
else:
self._interaction = lambda x, *args, **kwargs: x
self._output_layer = nn.Linear(params.hidden_dim, self.latent_dim)
    def encode_agents(self, x: torch.Tensor, mask_x: torch.Tensor, *args, **kwargs):
        """Encode agent trajectories into per-agent feature vectors. Must be overridden."""
        raise NotImplementedError
def forward(
self,
x: torch.Tensor,
mask_x: torch.Tensor,
encoded_absolute: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
y: Optional[torch.Tensor] = None,
mask_y: Optional[torch.Tensor] = None,
x_ego: Optional[torch.Tensor] = None,
y_ego: Optional[torch.Tensor] = None,
offset: Optional[torch.Tensor] = None,
risk_level: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward function that encodes input tensors into an output tensor of dimension
latent_dim.
Args:
x: (batch_size, num_agents, num_steps, state_dim) tensor of history
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
mask_map: (batch_size, num_objects) tensor of bool mask
y (optional): (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
mask_y (optional): (batch_size, num_agents, num_steps_future) tensor of bool mask. Defaults to None.
x_ego: (batch_size, 1, num_steps, state_dim) ego history
y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
offset (optional): (batch_size, num_agents, state_dim) offset position from ego.
risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future
trajectories. Defaults to None.
Returns:
(batch_size, num_agents, latent_dim) output tensor
"""
h_agents = self.encode_agents(
x=x,
mask_x=mask_x,
y=y,
mask_y=mask_y,
x_ego=x_ego,
y_ego=y_ego,
offset=offset,
risk_level=risk_level,
)
        # An agent is considered present if at least one of its history steps is observed.
        mask_agent = mask_x.any(-1)
h_agents = self._interaction(
h_agents, mask_agent, encoded_absolute, encoded_map, mask_map
)
return self._output_layer(h_agents)
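
# A hypothetical usage sketch (not part of the original module) showing the data
# flow through `forward` for a concrete subclass. All tensor values and the
# `params` construction are illustrative assumptions; only the shapes follow the
# docstring above:
#
#     params = CVAEParams(...)  # must define the fields read in __init__ above
#     encoder = InferenceEncoderNN(params, latent_dim=16, num_steps=params.num_steps)
#     x = torch.randn(8, 4, params.num_steps, params.state_dim)
#     mask_x = torch.ones(8, 4, params.num_steps, dtype=torch.bool)
#     encoded_absolute = torch.randn(8, 4, params.hidden_dim)
#     encoded_map = torch.randn(8, 10, params.hidden_dim)
#     mask_map = torch.ones(8, 10, dtype=torch.bool)
#     out = encoder(x, mask_x, encoded_absolute, encoded_map, mask_map)
#     # out.shape == (8, 4, 16): one latent_dim vector per agent
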
class BiasedEncoderNN(BaseEncoderNN):
"""Biased encoder neural network that encodes past info and auxiliary input
into a biased distribution over the latent space.
Args:
params: dataclass defining the necessary parameters
num_steps: length of the input sequence
"""
def __init__(
self,
params: CVAEParams,
latent_dim: int,
num_steps: int,
) -> None:
super().__init__(params, latent_dim, num_steps)
self.condition_on_ego_future = params.condition_on_ego_future
if params.sequence_encoder_type == "MLP":
self._ego_encoder = SequenceEncoderMLP(
params.state_dim,
params.hidden_dim,
params.num_hidden_layers,
params.num_steps
+ params.num_steps_future * self.condition_on_ego_future,
params.is_mlp_residual,
)
elif params.sequence_encoder_type == "LSTM":
self._ego_encoder = SequenceEncoderLSTM(params.state_dim, params.hidden_dim)
elif params.sequence_encoder_type == "maskedLSTM":
self._ego_encoder = SequenceEncoderMaskedLSTM(
params.state_dim, params.hidden_dim
)
self._auxiliary_encode = nn.Linear(
params.hidden_dim + 1 + params.hidden_dim, params.hidden_dim
)
def biased_parameters(self, recurse: bool = True):
"""Get the parameters to be optimized when training to bias."""
yield from self.parameters(recurse)
def encode_agents(
self,
x: torch.Tensor,
mask_x: torch.Tensor,
*,
x_ego: torch.Tensor,
y_ego: torch.Tensor,
offset: torch.Tensor,
risk_level: torch.Tensor,
**kwargs,
):
"""Encode agent input and auxiliary input into a feature vector.
Args:
x: (batch_size, num_agents, num_steps, state_dim) tensor of history
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
x_ego: (batch_size, 1, num_steps, state_dim) ego history
y_ego: (batch_size, 1, num_steps_future, state_dim) ego future
offset: (batch_size, num_agents, state_dim) offset position from ego.
risk_level: (batch_size, num_agents) tensor of risk levels desired for future
trajectories. Defaults to None.
Returns:
(batch_size, latent_dim) output tensor
"""
if self.condition_on_ego_future:
ego_tensor = torch.cat([x_ego, y_ego], dim=-2)
else:
ego_tensor = x_ego
        # Map risk_level in [0, 1] to exp(-5)..exp(5) so the conditioning signal
        # spans several orders of magnitude around 1 at risk 0.5.
        risk_feature = ((risk_level - 0.5) * 10).exp().unsqueeze(-1)
        # The ego trajectory is always fully observed, so its mask is all ones.
        mask_ego = torch.ones(
            ego_tensor.shape[0],
            offset.shape[1],
            ego_tensor.shape[2],
            device=ego_tensor.device,
        )
        batch_size, n_agents, dynamic_state_dim = offset.shape
        state_dim = ego_tensor.shape[-1]
        # Zero-pad the offset so it can be subtracted from full state vectors.
        extended_offset = torch.cat(
            (
                offset,
                torch.zeros(
                    batch_size,
                    n_agents,
                    state_dim - dynamic_state_dim,
                    device=offset.device,
                ),
            ),
            dim=-1,
        ).unsqueeze(-2)
        # Express the ego trajectory in each agent's local frame before encoding.
        if extended_offset.shape[1] > 1:
            ego_encoded = self._ego_encoder(
                ego_tensor + extended_offset[:, :1] - extended_offset, mask_ego
            )
        else:
            ego_encoded = self._ego_encoder(ego_tensor - extended_offset, mask_ego)
auxiliary_input = torch.cat((risk_feature, ego_encoded), -1)
h_agents = self._agent_encoder(x, mask_x)
h_agents = torch.cat([h_agents, auxiliary_input], dim=-1)
h_agents = self._auxiliary_encode(h_agents)
return h_agents
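
# Hypothetical illustration (not in the original file) of the risk feature used
# in `encode_agents` above: risk levels in [0, 1] are centered at 0.5, scaled by
# 10, and exponentiated, so low and high risk are pushed far apart:
#
#     >>> ((torch.tensor([0.0, 0.5, 1.0]) - 0.5) * 10).exp()
#     tensor([6.7379e-03, 1.0000e+00, 1.4841e+02])
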
class InferenceEncoderNN(BaseEncoderNN):
"""Inference encoder neural network that encodes past info into the
inference distribution over the latent space.

    Args:
        params: dataclass defining the necessary parameters
        latent_dim: dimension of the latent space the encoder maps to
        num_steps: length of the input sequence
    """
    def biased_parameters(self, recurse: bool = True):
        """The inference encoder is not optimized when training to bias."""
        yield from []
    def encode_agents(self, x: torch.Tensor, mask_x: torch.Tensor, *args, **kwargs):
        """Encode the past agent trajectories into per-agent feature vectors."""
        h_agents = self._agent_encoder(x, mask_x)
        return h_agents
class FutureEncoderNN(BaseEncoderNN):
"""Future encoder neural network that encodes past and future info into the
future-conditioned distribution over the latent space.
The future is not available at test time, this is only used for training.
Args:
params: dataclass defining the necessary parameters
num_steps: length of the input sequence
"""
    def biased_parameters(self, recurse: bool = True):
"""The future encoder is not optimized when training to bias."""
yield from []
def encode_agents(
self,
x: torch.Tensor,
mask_x: torch.Tensor,
*,
y: torch.Tensor,
mask_y: torch.Tensor,
**kwargs,
):
"""Encode agent input and future input into a feature vector.
Args:
x: (batch_size, num_agents, num_steps, state_dim) tensor of trajectory history
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
y: (batch_size, num_agents, num_steps_future, state_dim) future trajectory
mask_y: (batch_size, num_agents, num_steps_future) tensor of bool mask
"""
mask_traj = torch.cat([mask_x, mask_y], dim=-1)
h_agents = self._agent_encoder(torch.cat([x, y], dim=-2), mask_traj)
return h_agents
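
# A hypothetical shape walkthrough (not part of the original module) of the
# concatenation done in FutureEncoderNN.encode_agents: past and future are
# stacked along the time axis before a single sequence encoding. The sizes
# below are illustrative assumptions:
#
#     x:      (batch, agents, num_steps,        state_dim)  e.g. (8, 4, 10, 5)
#     y:      (batch, agents, num_steps_future, state_dim)  e.g. (8, 4, 20, 5)
#     torch.cat([x, y], dim=-2)           -> (8, 4, 30, 5)
#     torch.cat([mask_x, mask_y], dim=-1) -> (8, 4, 30)
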
class CVAEEncoder(nn.Module):
"""Encoder architecture for conditional variational autoencoder
Args:
model: encoder neural network that transforms input tensors to an unsplitted latent output
latent_distribution_creator: Class that creates a latent distribution class for the latent space.
"""
def __init__(
self,
model: BaseEncoderNN,
latent_distribution_creator,
) -> None:
super().__init__()
self._model = model
self.latent_dim = model.latent_dim
self._latent_distribution_creator = latent_distribution_creator
    def biased_parameters(self, recurse: bool = True):
        """Yield the parameters to be optimized when training to bias."""
        yield from self._model.biased_parameters(recurse)
def forward(
self,
x: torch.Tensor,
mask_x: torch.Tensor,
encoded_absolute: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
y: Optional[torch.Tensor] = None,
mask_y: Optional[torch.Tensor] = None,
x_ego: Optional[torch.Tensor] = None,
y_ego: Optional[torch.Tensor] = None,
offset: Optional[torch.Tensor] = None,
risk_level: Optional[torch.Tensor] = None,
) -> AbstractLatentDistribution:
"""Forward function that encodes input tensors into an output tensor of dimension
latent_dim.
Args:
x: (batch_size, num_agents, num_steps, state_dim) tensor of history
mask_x: (batch_size, num_agents, num_steps) tensor of bool mask
encoded_absolute: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, map_feature_dim) tensor of encoded map objects
mask_map: (batch_size, num_objects) tensor of bool mask
y (optional): (batch_size, num_agents, num_steps_future, state_dim) tensor of future trajectory.
mask_y (optional): (batch_size, num_agents, num_steps_future) tensor of bool mask. Defaults to None.
x_ego (optional): (batch_size, 1, num_steps, state_dim) ego history
y_ego (optional): (batch_size, 1, num_steps_future, state_dim) ego future
offset (optional): (batch_size, num_agents, state_dim) offset position from ego.
risk_level (optional): (batch_size, num_agents) tensor of risk levels desired for future
trajectories. Defaults to None.
Returns:
Latent distribution representing the posterior over the latent variables given the input observations.
"""
latent_output = self._model(
x=x,
mask_x=mask_x,
encoded_absolute=encoded_absolute,
encoded_map=encoded_map,
mask_map=mask_map,
y=y,
mask_y=mask_y,
x_ego=x_ego,
y_ego=y_ego,
offset=offset,
risk_level=risk_level,
)
latent_distribution = self._latent_distribution_creator(latent_output)
return latent_distribution
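
# A hypothetical end-to-end sketch (not part of the original module): wiring a
# FutureEncoderNN into CVAEEncoder for training. `SomeLatentDistribution` is an
# assumed placeholder; any callable mapping the (batch_size, num_agents,
# latent_dim) encoder output to an AbstractLatentDistribution fits the
# interface used in `forward` above. The `num_steps` value is likewise an
# assumption, chosen to match the concatenated past-and-future input:
#
#     params = CVAEParams(...)  # must define the fields read by the encoders
#     model = FutureEncoderNN(
#         params, latent_dim=16, num_steps=params.num_steps + params.num_steps_future
#     )
#     encoder = CVAEEncoder(model, latent_distribution_creator=SomeLatentDistribution)
#     posterior = encoder(
#         x, mask_x, encoded_absolute, encoded_map, mask_map, y=y, mask_y=mask_y
#     )  # -> AbstractLatentDistribution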