from typing import Tuple
from einops.layers.torch import Rearrange
from einops import rearrange, repeat
import torch
import torch.nn as nn
from risk_biased.models.multi_head_attention import MultiHeadAttention
from risk_biased.models.context_gating import ContextGating
from risk_biased.models.mlp import MLP
class SequenceEncoderMaskedLSTM(nn.Module):
"""MLP followed with a masked LSTM implementation with one layer.
Args:
input_dim : dimension of the input variable
h_dim : dimension of a hidden layer of MLP
"""
def __init__(self, input_dim: int, h_dim: int) -> None:
super().__init__()
self._group_objects = Rearrange("b o ... -> (b o) ...")
self._embed = nn.Linear(in_features=input_dim, out_features=h_dim)
        self._lstm = nn.LSTMCell(
            input_size=h_dim, hidden_size=h_dim
        )  # LSTMCell expects (batch, features) inputs at each step
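        # Learned initial hidden and cell states, repeated over the flattened batch at run time.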
self.h0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
"""Forward function for MapEncoder
Args:
input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
Returns:
torch.Tensor: (batch_size, num_objects, output_dim) tensor
"""
batch_size, num_objects, seq_len, _ = input.shape
split_objects = Rearrange("(b o) f -> b o f", b=batch_size, o=num_objects)
input = self._group_objects(input)
mask_input = self._group_objects(mask_input)
embedded_input = self._embed(input)
        # Many-to-one encoding of the input sequence with masking for missing points
mask_input = mask_input.float()
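        # Initial hidden state: first embedded observation where the mask is True, learned h0 elsewhere.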
h = mask_input[:, 0, None] * embedded_input[:, 0, :] + (
1 - mask_input[:, 0, None]
) * repeat(self.h0, "b f -> (size b) f", size=batch_size * num_objects)
c = repeat(self.c0, "b f -> (size b) f", size=batch_size * num_objects)
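        # Unroll the LSTM cell step by step; at missing time steps the previous hidden state is fed back as input.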
for i in range(seq_len):
new_input = (
mask_input[:, i, None] * embedded_input[:, i, :]
+ (1 - mask_input[:, i, None]) * h
)
h, c = self._lstm(new_input, (h, c))
return split_objects(h)
class SequenceEncoderLSTM(nn.Module):
"""MLP followed with an LSTM with one layer.
Args:
input_dim : dimension of the input variable
h_dim : dimension of a hidden layer of MLP
"""
def __init__(self, input_dim: int, h_dim: int) -> None:
super().__init__()
self._group_objects = Rearrange("b o ... -> (b o) ...")
self._embed = nn.Linear(in_features=input_dim, out_features=h_dim)
self._lstm = nn.LSTM(
input_size=h_dim,
hidden_size=h_dim,
batch_first=True,
        )  # batch_first=True: expects (batch, seq, features)
self.h0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
"""Forward function for MapEncoder
Args:
input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
Returns:
torch.Tensor: (batch_size, num_objects, output_dim) tensor
"""
batch_size, num_objects, seq_len, _ = input.shape
split_objects = Rearrange("(b o) f -> b o f", b=batch_size, o=num_objects)
input = self._group_objects(input)
mask_input = self._group_objects(mask_input)
embedded_input = self._embed(input)
        # Many-to-one encoding of the input sequence; only the initial state is masked for missing points
mask_input = mask_input.float()
h = (
mask_input[:, 0, None] * embedded_input[:, 0, :]
+ (1 - mask_input[:, 0, None])
* repeat(
self.h0, "one f -> one size f", size=batch_size * num_objects
).contiguous()
)
c = repeat(
self.c0, "one f -> one size f", size=batch_size * num_objects
).contiguous()
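        # The whole sequence is processed in a single LSTM call; per-step masking is not applied here
        # (see SequenceEncoderMaskedLSTM for a fully masked unrolling).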
_, (h, _) = self._lstm(embedded_input, (h, c))
# for i in range(seq_len):
# new_input = (
# mask_input[:, i, None] * embedded_input[:, i, :]
# + (1 - mask_input[:, i, None]) * h
# )
# h, c = self._lstm(new_input, (h, c))
return split_objects(h.squeeze(0))
class SequenceEncoderMLP(nn.Module):
"""MLP implementation.
Args:
input_dim : dimension of the input variable
h_dim : dimension of a hidden layer of MLP
num_layers: number of layers to use in the MLP
sequence_length: dimension of the input sequence
is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
"""
def __init__(
self,
input_dim: int,
h_dim: int,
num_layers: int,
sequence_length: int,
is_mlp_residual: bool,
) -> None:
super().__init__()
self._mlp = MLP(
input_dim * sequence_length, h_dim, h_dim, num_layers, is_mlp_residual
)
def forward(self, input: torch.Tensor, mask_input: torch.Tensor) -> torch.Tensor:
"""Forward function for MapEncoder
Args:
input (torch.Tensor): (batch_size, num_objects, seq_len, input_dim) tensor
mask_input (torch.Tensor): (batch_size, num_objects, seq_len) bool tensor (True if data is good False if data is missing)
Returns:
torch.Tensor: (batch_size, num_objects, output_dim) tensor
"""
batch_size, num_objects, _, _ = input.shape
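        # Zero-out missing time steps, then flatten each object's sequence into a single feature vector for the MLP.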
input = input * mask_input.unsqueeze(-1)
h = rearrange(input, "b o s f -> (b o) (s f)")
mask_input = rearrange(mask_input, "b o s -> (b o) s")
        if h.shape[-1] == 0:
            # Degenerate case (empty sequence or zero input features): return empty per-object features
            # so the rearrange below still produces a (batch_size, num_objects, 0) tensor.
            h = h.view(batch_size * num_objects, 0)
else:
h = self._mlp(h)
h = rearrange(h, "(b o) f -> b o f", b=batch_size, o=num_objects)
return h
class SequenceDecoderLSTM(nn.Module):
"""A one to many LSTM implementation with one layer.
Args:
h_dim : dimension of a hidden layer
"""
def __init__(self, h_dim: int) -> None:
super().__init__()
self._group_objects = Rearrange("b o f -> (b o) f")
        self._lstm = nn.LSTM(input_size=h_dim, hidden_size=h_dim)  # expects (seq, batch, features)
self._out_layer = nn.Linear(in_features=h_dim, out_features=h_dim)
self.c0 = nn.parameter.Parameter(torch.zeros(1, h_dim))
def forward(self, input: torch.Tensor, sequence_length: int) -> torch.Tensor:
"""Forward function for MapEncoder
Args:
input (torch.Tensor): (batch_size, num_objects, input_dim) tensor
sequence_length: output sequence length to create
Returns:
torch.Tensor: (batch_size, num_objects, output_dim) tensor
"""
batch_size, num_objects, _ = input.shape
h = repeat(input, "b o f -> one (b o) f", one=1).contiguous()
c = repeat(
self.c0, "one f -> one size f", size=batch_size * num_objects
).contiguous()
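        # One-to-many decoding: the encoded feature serves as the initial hidden state and, repeated over
        # sequence_length steps, as the LSTM input at every step.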
seq_h = repeat(h, "one b f -> (one t) b f", t=sequence_length).contiguous()
h, (_, _) = self._lstm(seq_h, (h, c))
h = rearrange(h, "t (b o) f -> b o t f", b=batch_size, o=num_objects)
return self._out_layer(h)
class SequenceDecoderMLP(nn.Module):
"""A one to many MLP implementation.
Args:
h_dim : dimension of a hidden layer
num_layers: number of layers to use in the MLP
sequence_length: output sequence length to return
is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
"""
def __init__(
self, h_dim: int, num_layers: int, sequence_length: int, is_mlp_residual: bool
) -> None:
super().__init__()
self._mlp = MLP(
h_dim, h_dim * sequence_length, h_dim, num_layers, is_mlp_residual
)
def forward(self, input: torch.Tensor, sequence_length: int) -> torch.Tensor:
"""Forward function for MapEncoder
Args:
input (torch.Tensor): (batch_size, num_objects, input_dim) tensor
sequence_length: output sequence length to create
Returns:
torch.Tensor: (batch_size, num_objects, output_dim) tensor
"""
batch_size, num_objects, _ = input.shape
h = rearrange(input, "b o f -> (b o) f")
h = self._mlp(h)
h = rearrange(
h, "(b o) (s f) -> b o s f", b=batch_size, o=num_objects, s=sequence_length
)
return h
class AttentionBlock(nn.Module):
"""Block performing agent-map cross attention->ReLU(linear)->+residual->layer_norm->agent-agent attention->ReLU(linear)->+residual->layer_norm
Args:
hidden_dim: feature dimension
num_attention_heads: number of attention heads to use
"""
def __init__(self, hidden_dim: int, num_attention_heads: int):
super().__init__()
self._num_attention_heads = num_attention_heads
self._agent_map_attention = MultiHeadAttention(
hidden_dim, num_attention_heads, hidden_dim, hidden_dim
)
self._lin1 = nn.Linear(hidden_dim, hidden_dim)
self._layer_norm1 = nn.LayerNorm(hidden_dim)
self._agent_agent_attention = MultiHeadAttention(
hidden_dim, num_attention_heads, hidden_dim, hidden_dim
)
self._lin2 = nn.Linear(hidden_dim, hidden_dim)
self._layer_norm2 = nn.LayerNorm(hidden_dim)
self._activation = nn.ReLU()
def forward(
self,
encoded_agents: torch.Tensor,
mask_agents: torch.Tensor,
encoded_absolute_agents: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
) -> torch.Tensor:
"""Forward function of the block, returning only the output (no attention matrix)
Args:
encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
"""
# Check if map_info is available. If not, don't compute cross-attention with it
if mask_map.any():
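            # Pairwise agent-map mask: cross-attention is only computed between valid agents and valid map objects.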
mask_agent_map = torch.einsum("ba,bo->bao", mask_agents, mask_map)
h, _ = self._agent_map_attention(
encoded_agents + encoded_absolute_agents,
encoded_map,
encoded_map,
mask=mask_agent_map,
)
h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
h = torch.sigmoid(self._lin1(h))
h = self._layer_norm1(encoded_agents + h)
else:
h = self._layer_norm1(encoded_agents)
h_res = h.clone()
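        # Pairwise agent-agent mask restricts self-attention to valid agents.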
agent_agent_mask = torch.einsum("ba,be->bae", mask_agents, mask_agents)
h = h + encoded_absolute_agents
h, _ = self._agent_agent_attention(h, h, h, mask=agent_agent_mask)
h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
h = self._activation(self._lin2(h))
h = self._layer_norm2(h_res + h)
return h
class CG_block(nn.Module):
"""Block performing context gating agent-map
Args:
hidden_dim: feature dimension
dim_expansion: multiplicative factor on the hidden dimension for the global context representation
num_layers: number of layers to use in the MLP for context encoding
is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
"""
def __init__(
self,
hidden_dim: int,
dim_expansion: int,
num_layers: int,
is_mlp_residual: bool,
):
super().__init__()
self._agent_map = ContextGating(
hidden_dim,
hidden_dim * dim_expansion,
num_layers=num_layers,
is_mlp_residual=is_mlp_residual,
)
self._lin1 = nn.Linear(hidden_dim, hidden_dim)
self._layer_norm1 = nn.LayerNorm(hidden_dim)
self._agent_agent = ContextGating(
hidden_dim, hidden_dim * dim_expansion, num_layers, is_mlp_residual
)
self._lin2 = nn.Linear(hidden_dim, hidden_dim)
self._activation = nn.ReLU()
def forward(
self,
encoded_agents: torch.Tensor,
mask_agents: torch.Tensor,
encoded_absolute_agents: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
global_context: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function of the block, returning the output and global context
Args:
encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
global_context: (batch_size, dim_context) tensor representing the global context
"""
# Check if map_info is available. If not, don't compute cross-interaction with it
if mask_map.any():
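            # Agent-map context gating updates both the per-agent features and the global context.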
s, global_context = self._agent_map(
encoded_agents + encoded_absolute_agents, encoded_map, global_context
)
s = s * mask_agents.unsqueeze(-1)
s = self._activation(self._lin1(s))
s = self._layer_norm1(encoded_agents + s)
else:
s = self._layer_norm1(encoded_agents)
s = s + encoded_absolute_agents
s, global_context = self._agent_agent(s, s, global_context)
s = s * mask_agents.unsqueeze(-1)
s = self._lin2(s)
return s, global_context
class HybridBlock(nn.Module):
"""Block performing agent-map cross context_gating->ReLU(linear)->+residual->layer_norm->agent-agent attention->ReLU(linear)->+residual->layer_norm
Args:
hidden_dim: feature dimension
num_attention_heads: number of attention heads to use
dim_expansion: multiplicative factor on the hidden dimension for the global context representation
num_layers: number of layers to use in the MLP for context encoding
is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
"""
def __init__(
self,
hidden_dim: int,
num_attention_heads: int,
dim_expansion: int,
num_layers: int,
is_mlp_residual: bool,
):
super().__init__()
self._num_attention_heads = num_attention_heads
self._agent_map_cg = ContextGating(
hidden_dim,
hidden_dim * dim_expansion,
num_layers=num_layers,
is_mlp_residual=is_mlp_residual,
)
self._lin1 = nn.Linear(hidden_dim, hidden_dim)
self._layer_norm1 = nn.LayerNorm(hidden_dim)
self._agent_agent_attention = MultiHeadAttention(
hidden_dim, num_attention_heads, hidden_dim, hidden_dim
)
self._lin2 = nn.Linear(hidden_dim, hidden_dim)
self._layer_norm2 = nn.LayerNorm(hidden_dim)
self._activation = nn.ReLU()
def forward(
self,
encoded_agents: torch.Tensor,
mask_agents: torch.Tensor,
encoded_absolute_agents: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
global_context: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function of the block, returning the output and the context (no attention matrix)
Args:
encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
global_context: (batch_size, dim_context) tensor representing the global context
"""
# Check if map_info is available. If not, don't compute cross-context gating with it
if mask_map.any():
# mask_agent_map = torch.logical_not(
# torch.einsum("ba,bo->bao", mask_agents, mask_map)
# )
h, global_context = self._agent_map_cg(
encoded_agents + encoded_absolute_agents, encoded_map, global_context
)
h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
h = self._activation(self._lin1(h))
h = self._layer_norm1(encoded_agents + h)
else:
h = self._layer_norm1(encoded_agents)
h_res = h.clone()
agent_agent_mask = torch.einsum("ba,be->bae", mask_agents, mask_agents)
h = h + encoded_absolute_agents
h, _ = self._agent_agent_attention(h, h, h, mask=agent_agent_mask)
h = torch.masked_fill(h, torch.logical_not(mask_agents.unsqueeze(-1)), 0)
h = self._activation(self._lin2(h))
h = self._layer_norm2(h_res + h)
return h, global_context
class MCG(nn.Module):
"""Multiple context encoding blocks
Args:
hidden_dim: feature dimension
dim_expansion: multiplicative factor on the hidden dimension for the global context representation
num_layers: number of layers to use in the MLP for context encoding
num_blocks: number of successive context encoding blocks to use in the module
is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
"""
def __init__(
self,
hidden_dim: int,
dim_expansion: int,
num_layers: int,
num_blocks: int,
is_mlp_residual: bool,
):
super().__init__()
self.initial_global_context = nn.parameter.Parameter(
torch.ones(1, hidden_dim * dim_expansion)
)
list_cg = []
for i in range(num_blocks):
list_cg.append(
CG_block(hidden_dim, dim_expansion, num_layers, is_mlp_residual)
)
self.mcg = nn.ModuleList(list_cg)
def forward(
self,
encoded_agents: torch.Tensor,
mask_agents: torch.Tensor,
encoded_absolute_agents: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
) -> torch.Tensor:
"""Forward function of the block, returning only the output (no context)
Args:
encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
"""
s = encoded_agents
c = self.initial_global_context
sum_s = s
sum_c = c
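        # Each block refines a running mean of the agent features and of the global context (initial values included).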
for i, cg in enumerate(self.mcg):
s_new, c_new = cg(
s, mask_agents, encoded_absolute_agents, encoded_map, mask_map, c
)
sum_s = sum_s + s_new
sum_c = sum_c + c_new
s = (sum_s / (i + 2)).clone()
c = (sum_c / (i + 2)).clone()
return s
class MAB(nn.Module):
"""Multiple Attention Blocks
Args:
hidden_dim: feature dimension
num_attention_heads: number of attention heads to use
num_blocks: number of successive blocks to use in the module.
"""
def __init__(
self,
hidden_dim: int,
num_attention_heads: int,
num_blocks: int,
):
super().__init__()
list_attention = []
for i in range(num_blocks):
list_attention.append(AttentionBlock(hidden_dim, num_attention_heads))
self.attention_blocks = nn.ModuleList(list_attention)
def forward(
self,
encoded_agents: torch.Tensor,
mask_agents: torch.Tensor,
encoded_absolute_agents: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
) -> torch.Tensor:
"""Forward function of the block, returning only the output (no attention matrix)
Args:
encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
"""
h = encoded_agents
sum_h = h
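        # Outputs of successive attention blocks are averaged with the input encoding (running mean).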
for i, attention in enumerate(self.attention_blocks):
h_new = attention(
h, mask_agents, encoded_absolute_agents, encoded_map, mask_map
)
sum_h = sum_h + h_new
h = (sum_h / (i + 2)).clone()
return h
class MHB(nn.Module):
"""Multiple Hybrid Blocks
Args:
hidden_dim: feature dimension
num_attention_heads: number of attention heads to use
dim_expansion: multiplicative factor on the hidden dimension for the global context representation
num_layers: number of layers to use in the MLP for context encoding
num_blocks: number of successive blocks to use in the module.
is_mlp_residual: set to True to add a linear transformation of the input to the output of the MLP
"""
def __init__(
self,
hidden_dim: int,
num_attention_heads: int,
dim_expansion: int,
num_layers: int,
num_blocks: int,
is_mlp_residual: bool,
):
super().__init__()
self.initial_global_context = nn.parameter.Parameter(
torch.ones(1, hidden_dim * dim_expansion)
)
list_hb = []
for i in range(num_blocks):
list_hb.append(
HybridBlock(
hidden_dim,
num_attention_heads,
dim_expansion,
num_layers,
is_mlp_residual,
)
)
self.hybrid_blocks = nn.ModuleList(list_hb)
def forward(
self,
encoded_agents: torch.Tensor,
mask_agents: torch.Tensor,
encoded_absolute_agents: torch.Tensor,
encoded_map: torch.Tensor,
mask_map: torch.Tensor,
) -> torch.Tensor:
"""Forward function of the block, returning only the output (no attention matrix nor context)
Args:
encoded_agents: (batch_size, num_agents, feature_size) tensor of the encoded agent tracks
mask_agents: (batch_size, num_agents) tensor True if agent track is good False if it is just padding
encoded_absolute_agents: (batch_size, num_agents, feature_size) tensor of the encoded absolute agent positions
encoded_map: (batch_size, num_objects, feature_size) tensor of the encoded map object features
mask_map: (batch_size, num_objects) tensor True if object is good False if it is just padding
"""
sum_h = encoded_agents
sum_c = self.initial_global_context
h = encoded_agents
c = self.initial_global_context
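        # Running mean over hybrid block outputs, for both the agent features and the global context.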
for i, hb in enumerate(self.hybrid_blocks):
h_new, c_new = hb(
h, mask_agents, encoded_absolute_agents, encoded_map, mask_map, c
)
sum_h = sum_h + h_new
sum_c = sum_c + c_new
h = (sum_h / (i + 2)).clone()
c = (sum_c / (i + 2)).clone()
return h
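# Minimal usage sketch of the sequence encoder/decoder pair; shapes follow the docstrings above,
# and the hyper-parameter values below are arbitrary placeholders, not values from this repository.
#
#   encoder = SequenceEncoderMLP(input_dim=2, h_dim=64, num_layers=3, sequence_length=10, is_mlp_residual=True)
#   decoder = SequenceDecoderMLP(h_dim=64, num_layers=3, sequence_length=12, is_mlp_residual=True)
#   x = torch.zeros(8, 5, 10, 2)                   # (batch, num_objects, seq_len, input_dim)
#   mask = torch.ones(8, 5, 10, dtype=torch.bool)  # True where data is valid
#   h = encoder(x, mask)                           # (8, 5, 64)
#   y = decoder(h, 12)                             # (8, 5, 12, 64)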