# Provenance (copied from the Hugging Face file page): uploaded by
# mohammed-aljafry with huggingface_hub ("Upload folder using huggingface_hub"),
# commit c89800e (verified), file size 40 kB.
# Cell 1: Clone Repository and Install Dependencies
!git clone https://github.com/opendilab/InterFuser.git
!pip install timm
import math
import copy
import logging
import sys
from collections import OrderedDict
from functools import partial
from typing import Optional, List
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# import wandb # Import wandb
# Add InterFuser to Python path
sys.path.append('/content/InterFuser')
# --- W&B Login ---
# You might need to provide your API key when running this in Colab
# try:
# wandb.login()
# except Exception as e:
# print(f"Wandb login failed. Please ensure you have provided your API key. Error: {e}")
# Import specific modules from the cloned repository (adjust paths if needed)
try:
# Assuming the structure within the cloned repo is InterFuser/interfuser/...
from InterFuser.interfuser.timm.models.layers import StdConv2dSame, StdConv2d, to_2tuple
from InterFuser.interfuser.timm.models.registry import register_model
# Note: The original code seemed to have local imports like '.resnet',
# these need to be adjusted based on the actual file structure after cloning.
# Using the direct import path assuming it's available after appending '/content'
from InterFuser.interfuser.timm.models.resnet import resnet26d, resnet50d, resnet18d, resnet26, resnet50, resnet101d
except ImportError as e:
print(f"Error importing from InterFuser repository: {e}")
print("Please ensure the repository structure is correct and accessible.")
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import json
from pathlib import Path
from torchvision import transforms
import os
import tqdm
# Configure the root logger so that INFO-level records (and above, including
# warnings) are printed.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
_logger = logging.getLogger(__name__)
class HybridEmbed(nn.Module):
    """Embed images with a CNN backbone followed by a 1x1 projection.

    The backbone's feature map (the last one, if it returns a list/tuple) is
    projected to ``embed_dim`` channels. ``forward`` returns both the projected
    map and its spatial mean.

    Args:
        backbone: CNN ``nn.Module``; may return a tensor or a list/tuple of
            feature maps.
        img_size: input size (int or pair) used for the probing forward pass.
        patch_size: stored on the module; not otherwise used in this class.
        feature_size: backbone output spatial size. When ``None``, it is
            discovered (together with the channel count) by running the
            backbone once on a zero image.
        in_chans: channel count of the probing zero image.
        embed_dim: output channels of the 1x1 projection.
    """

    def __init__(
        self,
        backbone,
        img_size=224,
        patch_size=1,
        feature_size=None,
        in_chans=3,
        embed_dim=768,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.img_size = img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.backbone = backbone

        if feature_size is not None:
            feature_size = to_2tuple(feature_size)
            feature_dim = (
                self.backbone.feature_info.channels()[-1]
                if hasattr(self.backbone, "feature_info")
                else self.backbone.num_features
            )
        else:
            # Probe the backbone once in eval mode to learn its output channel
            # count, then restore whatever train/eval mode it had before.
            was_training = backbone.training
            with torch.no_grad():
                if was_training:
                    backbone.eval()
                probe = self.backbone(
                    torch.zeros(1, in_chans, img_size[0], img_size[1])
                )
                if isinstance(probe, (list, tuple)):
                    probe = probe[-1]
                feature_size = probe.shape[-2:]
                feature_dim = probe.shape[1]
                backbone.train(was_training)

        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)

    def forward(self, x):
        """Return ``(tokens, global_token)``.

        ``tokens`` is the projected feature map; ``global_token`` is its mean
        over the two spatial axes with a trailing singleton axis appended.
        """
        feats = self.backbone(x)
        if isinstance(feats, (list, tuple)):
            feats = feats[-1]
        tokens = self.proj(feats)
        global_token = tokens.mean(dim=(2, 3)).unsqueeze(-1)
        return tokens, global_token
class PositionEmbeddingSine(nn.Module):
    """Fixed 2-D sine/cosine positional encoding (as in DETR).

    For an input of shape (B, C, H, W) the output has shape
    (B, 2 * num_pos_feats, H, W). The first ``num_pos_feats`` channels encode
    the 1-based row index and the rest encode the 1-based column index. Sine
    and cosine terms alternate channel by channel. Only the input's shape and
    device are used, never its values.

    Args:
        num_pos_feats: channels per axis (expected to be even).
        temperature: base of the geometric frequency progression.
        normalize: if True, indices are divided by the axis length and
            multiplied by ``scale``.
        scale: normalisation scale; defaults to ``2 * pi``. Only allowed when
            ``normalize`` is True.

    Raises:
        ValueError: if ``scale`` is given while ``normalize`` is False.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, tensor):
        bs, _, h, w = tensor.shape
        device = tensor.device

        # 1-based indices along each axis (equal to a cumulative sum of ones).
        rows = torch.arange(1, h + 1, dtype=torch.float32, device=device)
        cols = torch.arange(1, w + 1, dtype=torch.float32, device=device)
        if self.normalize:
            eps = 1e-6
            rows = rows / (rows[-1:] + eps) * self.scale
            cols = cols / (cols[-1:] + eps) * self.scale
        y_embed = rows.view(1, h, 1).expand(bs, h, w)
        x_embed = cols.view(1, 1, w).expand(bs, h, w)

        freq = torch.arange(self.num_pos_feats, dtype=torch.float32, device=device)
        dim_t = self.temperature ** (2 * (freq // 2) / self.num_pos_feats)

        def interleave(coords):
            # Divide by every frequency, then alternate sin (even) / cos (odd).
            scaled = coords[:, :, :, None] / dim_t
            return torch.stack(
                (scaled[:, :, :, 0::2].sin(), scaled[:, :, :, 1::2].cos()), dim=4
            ).flatten(3)

        pos = torch.cat((interleave(y_embed), interleave(x_embed)), dim=3)
        return pos.permute(0, 3, 1, 2)
class TransformerEncoder(nn.Module):
    """Apply ``num_layers`` independent copies of ``encoder_layer`` in sequence.

    Every layer receives the same ``mask``, ``src_key_padding_mask`` and
    ``pos``. If ``norm`` is given, it is applied once to the final output.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        hidden = src
        for block in self.layers:
            hidden = block(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        return hidden if self.norm is None else self.norm(hidden)
class SpatialSoftmax(nn.Module):
    """Spatial soft-argmax: turn each channel's map into one scaled 2-D point.

    Each (N, C) slice is flattened to ``height * width`` logits and
    softmax-normalised with ``temperature``. The resulting weights average two
    fixed coordinate grids spanning [-1, 1]. The averaged point ``(x, y)`` is
    then mapped to ``(x * 12, (y - 1) * 12)``.

    Args:
        height, width: spatial size of the incoming maps.
        channel: number of channels (one keypoint per channel).
        temperature: if truthy, a learnable temperature initialised to this
            value; otherwise a fixed 1.0.
        data_format: "NCHW" (default) or "NHWC".

    NOTE(review): the coordinate grids come from ``np.meshgrid`` over
    ``linspace(height)`` and ``linspace(width)`` and are only axis-aligned
    with a row-major (H, W) flattening when ``height == width`` (the only case
    used in this file, 100x100).
    """

    def __init__(self, height, width, channel, temperature=None, data_format="NCHW"):
        super().__init__()
        self.data_format = data_format
        self.height = height
        self.width = width
        self.channel = channel
        if temperature:
            # ``Parameter`` was never imported at module level; use nn.Parameter.
            self.temperature = nn.Parameter(torch.ones(1) * temperature)
        else:
            self.temperature = 1.0
        pos_x, pos_y = np.meshgrid(
            np.linspace(-1.0, 1.0, self.height), np.linspace(-1.0, 1.0, self.width)
        )
        pos_x = torch.from_numpy(pos_x.reshape(self.height * self.width)).float()
        pos_y = torch.from_numpy(pos_y.reshape(self.height * self.width)).float()
        self.register_buffer("pos_x", pos_x)
        self.register_buffer("pos_y", pos_y)

    def forward(self, feature):
        """Return keypoints of shape (N, channel, 2) ordered as (x, y)."""
        if self.data_format == "NHWC":
            # (N, H, W, C) -> (N, C, W, H) -> (N, C, H, W)
            feature = feature.transpose(1, 3).transpose(2, 3)
        # reshape (not view): the NHWC path is non-contiguous after transposes.
        feature = feature.reshape(-1, self.height * self.width)
        weight = F.softmax(feature / self.temperature, dim=-1)
        expected_x = torch.sum(self.pos_x * weight, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * weight, dim=1, keepdim=True)
        keypoints = torch.cat([expected_x, expected_y], 1).view(-1, self.channel, 2)
        # Scale out-of-place instead of writing into a view of the cat result.
        x = keypoints[:, :, 0] * 12
        y = (keypoints[:, :, 1] - 1) * 12
        return torch.stack([x, y], dim=-1)
class MultiPath_Generator(nn.Module):
    """Heatmap waypoint head with one output branch per driving command.

    The token sequence is reshaped into a 2x2 grid and concatenated with 32
    velocity channels. It is then upsampled by four transposed-conv stages and
    resized to 50x50. Six heads each produce ``out_channel`` 100x100 maps
    (Tanh). They are mixed with the command weights ``measurements[:, :6]``
    and reduced to keypoints by ``SpatialSoftmax``.

    Args:
        in_channel: channels after concatenating tokens and velocity.
        embed_dim: accepted for interface compatibility; unused here.
        out_channel: number of heatmaps (keypoints) per branch.
    """

    def __init__(self, in_channel, embed_dim, out_channel):
        super().__init__()
        self.spatial_softmax = SpatialSoftmax(100, 100, out_channel)

        def up_block(c_in, c_out):
            # 2x upsampling stage: transposed conv -> batch norm -> ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            )

        self.tconv0 = up_block(in_channel, 256)
        self.tconv1 = up_block(256, 256)
        self.tconv2 = up_block(256, 192)
        self.tconv3 = up_block(192, 64)
        self.tconv4_list = torch.nn.ModuleList(
            nn.Sequential(
                nn.ConvTranspose2d(64, out_channel, 8, 2, 3, bias=False),
                nn.Tanh(),
            )
            for _ in range(6)
        )
        self.upsample = nn.Upsample(size=(50, 50), mode="bilinear")

    def forward(self, x, measurements):
        """x: (N, 4, C) tokens; returns SpatialSoftmax keypoints."""
        n, _, _ = x.shape
        # Command weights, broadcast over heatmap channels and pixels.
        command = measurements[:, :6][:, :, None, None, None]
        velocity = measurements[:, 6:7, None, None].repeat(1, 32, 2, 2)

        grid = x.transpose(1, 2).view(n, -1, 2, 2)
        x = torch.cat([grid, velocity], dim=1)
        for stage in (self.tconv0, self.tconv1, self.tconv2, self.tconv3):
            x = stage(x)
        x = self.upsample(x)

        heatmaps = torch.stack([head(x) for head in self.tconv4_list], dim=1)
        mixed = (heatmaps * command).sum(dim=1)
        return self.spatial_softmax(mixed)
class LinearWaypointsPredictor(nn.Module):
    """Per-command MLP waypoint head.

    Six command branches (Linear -> ReLU -> Linear) each map every waypoint
    token to a 2-D offset. The branches are combined with the weights
    ``measurements[:, :6]``; a one-hot command selects a single branch.

    Args:
        input_dim: token feature size.
        cumsum: if True, offsets are accumulated along the waypoint axis.
    """

    def __init__(self, input_dim, cumsum=True):
        super().__init__()
        self.cumsum = cumsum
        # Learned per-rank embedding; the token count must be 10 to match.
        self.rank_embed = nn.Parameter(torch.zeros(1, 10, input_dim))
        self.head_fc1_list = nn.ModuleList([nn.Linear(input_dim, 64) for _ in range(6)])
        self.head_relu = nn.ReLU(inplace=True)
        self.head_fc2_list = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])

    def forward(self, x, measurements):
        """x: (bs, n, input_dim) with n == 10; returns (bs, n, 2)."""
        bs, n, dim = x.shape
        # Row r of the flattened tokens belongs to sample r // n.
        x = (x + self.rank_embed).reshape(-1, dim)
        rs = torch.stack(
            [
                fc2(self.head_relu(fc1(x)))
                for fc1, fc2 in zip(self.head_fc1_list, self.head_fc2_list)
            ],
            1,
        )  # (bs * n, 6, 2)
        # repeat_interleave (not repeat) so that rows b*n .. b*n+n-1 all use
        # sample b's command; repeat(n, ...) would tile the whole batch and
        # mix commands across samples whenever bs > 1.
        mask = measurements[:, :6].repeat_interleave(n, dim=0).unsqueeze(-1)
        x = torch.sum(rs * mask, dim=1).view(bs, n, 2)
        if self.cumsum:
            x = torch.cumsum(x, 1)
        return x
class GRUWaypointsPredictor(nn.Module):
    """GRU waypoint head conditioned on a 2-D target point.

    The target point is encoded into the GRU's initial hidden state. The GRU
    runs over the ``waypoints`` input tokens (batch-first), and each hidden
    state is decoded to a 2-D offset. The offsets are accumulated along the
    waypoint axis.
    """

    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
        self.encoder = nn.Linear(2, 64)
        self.decoder = nn.Linear(64, 2)
        self.waypoints = waypoints

    def forward(self, x, target_point):
        batch = x.shape[0]
        h0 = self.encoder(target_point)[None]  # leading num_layers axis of 1
        hidden_seq = self.gru(x, h0)[0]
        steps = self.decoder(hidden_seq.reshape(batch * self.waypoints, -1))
        return steps.reshape(batch, self.waypoints, 2).cumsum(dim=1)
class GRUWaypointsPredictorWithCommand(nn.Module):
    """GRU waypoint head with a separate GRU/decoder branch per command.

    All six branches share the target-point encoder that sets their initial
    hidden state. Each branch produces accumulated 2-D waypoints. The branch
    outputs are combined with the weights ``measurements[:, :6]``.
    """

    def __init__(self, input_dim, waypoints=10):
        super().__init__()
        self.grus = nn.ModuleList(
            [
                torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
                for _ in range(6)
            ]
        )
        self.encoder = nn.Linear(2, 64)
        self.decoders = nn.ModuleList([nn.Linear(64, 2) for _ in range(6)])
        self.waypoints = waypoints

    def forward(self, x, target_point, measurements):
        bs, _, _ = x.shape
        h0 = self.encoder(target_point).unsqueeze(0)
        branch_outputs = []
        for gru, decoder in zip(self.grus, self.decoders):
            hidden_seq, _ = gru(x, h0)
            offsets = decoder(hidden_seq.reshape(bs * self.waypoints, -1))
            branch_outputs.append(
                torch.cumsum(offsets.reshape(bs, self.waypoints, 2), 1)
            )
        stacked = torch.stack(branch_outputs, 1)  # (bs, 6, waypoints, 2)
        command = measurements[:, :6, None, None]  # broadcast over waypoints/xy
        return (stacked * command).sum(dim=1)
class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` independent copies of ``decoder_layer``.

    ``forward`` returns a tensor with a leading "layer" axis:
      * ``return_intermediate=False``: shape (1, ...), the final output after
        the optional ``norm``;
      * ``return_intermediate=True``: one entry per layer, each passed through
        ``norm`` when one is given. The last entry is the final output.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = nn.ModuleList(
            copy.deepcopy(decoder_layer) for _ in range(num_layers)
        )
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt
        intermediate = []
        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                # Without a norm, keep the raw per-layer output (previously
                # this called ``None(output)`` and raised TypeError).
                intermediate.append(
                    output if self.norm is None else self.norm(output)
                )
        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # The last entry equals the final normalised output; store
                # that exact tensor.
                intermediate.pop()
                intermediate.append(output)
        if self.return_intermediate:
            return torch.stack(intermediate)
        return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
    """DETR-style encoder layer: self-attention, then a feed-forward block.

    Each sub-block has a residual connection, dropout and LayerNorm. The norm
    is applied after the residual by default ("post-norm"), or before each
    sub-block when ``normalize_before`` is True ("pre-norm"). Attention uses
    ``nn.MultiheadAttention`` with its default sequence-first layout.
    ``pos`` is added to queries and keys only, not to values.

    Args:
        d_model: embedding size.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout probability used throughout.
        activation: an activation class (e.g. ``nn.GELU``), instantiated with
            no arguments, or an already-built ``nn.Module`` used as-is.
        normalize_before: select pre-norm instead of post-norm.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # The previous default ``nn.ReLU()`` was an instance, and calling it
        # with no input raised TypeError. Accept a class or an instance.
        self.activation = (
            activation if isinstance(activation, nn.Module) else activation()
        )
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Post-norm: residual add first, then LayerNorm."""
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        """Pre-norm: LayerNorm the sub-block input, then residual add."""
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
    """DETR-style decoder layer: self-attention, cross-attention, feed-forward.

    Each sub-block has a residual connection, dropout and LayerNorm (post-norm
    by default, pre-norm when ``normalize_before``). ``query_pos`` is added to
    the queries (and to the self-attention keys). ``pos`` is added to the
    memory keys in cross-attention; values never receive positions. Attention
    uses ``nn.MultiheadAttention`` with its default sequence-first layout.

    Args:
        d_model: embedding size.
        nhead: number of attention heads.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: dropout probability used throughout.
        activation: an activation class (e.g. ``nn.GELU``), instantiated with
            no arguments, or an already-built ``nn.Module`` used as-is.
        normalize_before: select pre-norm instead of post-norm.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=nn.ReLU,
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        # The previous default ``nn.ReLU()`` was an instance, and calling it
        # with no input raised TypeError. Accept a class or an instance.
        self.activation = (
            activation if isinstance(activation, nn.Module) else activation()
        )
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Post-norm: residual add first, then LayerNorm, for each sub-block."""
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        """Pre-norm: LayerNorm each sub-block's input, then residual add."""
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    clones = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(clones)
def _get_activation_fn(activation):
    """Return the functional activation named by ``activation``.

    Supported names are ``"relu"``, ``"gelu"`` and ``"glu"``.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    activations = {"relu": F.relu, "gelu": F.gelu, "glu": F.glu}
    try:
        return activations[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which previously also fell
        # through to this error. The message now lists every supported name.
        raise RuntimeError(
            f"activation should be relu/gelu/glu, not {activation}."
        ) from None
def build_attn_mask(mask_type, device="cuda"):
    """Build the fixed (151, 151) boolean encoder attention mask.

    ``True`` marks a query/key pair that may not attend (the boolean
    ``attn_mask`` convention of ``nn.MultiheadAttention``). The token sequence
    is split into five hard-coded groups: [0:50], [50:67], [67:84], [84:101]
    and [101:151]. These presumably match the per-view token blocks produced
    by ``Interfuser.forward_features``; confirm before changing image sizes.

    Args:
        mask_type: "seperate_all" lets every group attend only within itself.
            "seperate_view" does the same for the first four groups, while the
            last group [101:151] attends to, and is attended by, all tokens.
            (The spelling "seperate" is kept for existing callers.)
        device: device the mask is allocated on. Defaults to "cuda", matching
            the previous hard-coded ``.cuda()``.

    Raises:
        ValueError: for an unknown ``mask_type``. Previously this returned an
            all-True mask, which blocks every query/key pair.
    """
    if mask_type not in ("seperate_all", "seperate_view"):
        raise ValueError(f"unknown attention mask type: {mask_type!r}")
    mask = torch.ones((151, 151), dtype=torch.bool, device=device)
    # The first four groups only see themselves in both modes.
    for start, end in ((0, 50), (50, 67), (67, 84), (84, 101)):
        mask[start:end, start:end] = False
    if mask_type == "seperate_all":
        mask[101:151, 101:151] = False
    else:
        mask[101:151, :] = False
        mask[:, 101:151] = False
    return mask
class Interfuser(nn.Module):
    """InterFuser-style multi-view transformer for driving.

    Camera (and optionally LiDAR) inputs are embedded with CNN backbones via
    ``HybridEmbed``. Each input gets sine position, per-view and global
    embeddings, and all of them are concatenated into one sequence-first token
    sequence that a transformer encoder processes. A transformer decoder then
    queries this memory with learned queries:
      * 400 queries, positioned on a 20x20 sine grid, feed the traffic head;
      * query 400 feeds the junction, traffic-light and stop-sign heads;
      * the remaining queries feed the waypoint head chosen by
        ``waypoints_pred_head`` (``end2end`` uses only 4 queries and returns
        waypoints alone).
    """

    def __init__(
        self,
        img_size=224,
        multi_view_img_size=112,
        patch_size=8,
        in_chans=3,
        embed_dim=768,
        enc_depth=6,
        dec_depth=6,
        dim_feedforward=2048,
        normalize_before=False,
        rgb_backbone_name="r26",
        lidar_backbone_name="r26",
        num_heads=8,
        norm_layer=None,
        dropout=0.1,
        end2end=False,
        direct_concat=True,
        separate_view_attention=False,
        separate_all_attention=False,
        act_layer=None,
        weight_init="",
        freeze_num=-1,
        with_lidar=False,
        with_right_left_sensors=True,
        with_center_sensor=False,
        traffic_pred_head_type="det",
        waypoints_pred_head="heatmap",
        reverse_pos=True,
        use_different_backbone=False,
        use_view_embed=True,
        use_mmad_pretrain=None,
    ):
        """Build backbones, embeddings, encoder/decoder and prediction heads.

        Key options:
            direct_concat: concatenate the four camera images channel-wise
                into the front input (and disable the separate side/center
                branches).
            use_different_backbone: build separate RGB and LiDAR backbones
                instead of sharing a single RGB backbone.
            waypoints_pred_head: "heatmap", "gru", "gru-command", "linear" or
                "linear-sum".
            traffic_pred_head_type: "det" or "seg".
            use_mmad_pretrain: optional checkpoint path whose "backbone" weights
                are loaded into the RGB backbone (different-backbone mode only).

        NOTE(review): multi_view_img_size, weight_init and freeze_num are not
        read in this class; norm_layer is defaulted but never used, and
        reverse_pos is stored but never read.
        """
        super().__init__()
        self.traffic_pred_head_type = traffic_pred_head_type
        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.reverse_pos = reverse_pos
        self.waypoints_pred_head = waypoints_pred_head
        self.with_lidar = with_lidar
        self.with_right_left_sensors = with_right_left_sensors
        self.with_center_sensor = with_center_sensor
        self.direct_concat = direct_concat
        self.separate_view_attention = separate_view_attention
        self.separate_all_attention = separate_all_attention
        self.end2end = end2end
        self.use_view_embed = use_view_embed
        if self.direct_concat:
            # forward() stacks front/left/right/center channel-wise, so the
            # RGB input has 4x the channels and the per-view branches are off.
            in_chans = in_chans * 4
            self.with_center_sensor = False
            self.with_right_left_sensors = False
        # Optional block-structured encoder attention mask.
        # NOTE(review): build_attn_mask allocates on CUDA, and the mask is kept
        # as a plain attribute (not a registered buffer), so it is not moved by
        # .to()/.cpu() and these two options require a GPU.
        if self.separate_view_attention:
            self.attn_mask = build_attn_mask("seperate_view")
        elif self.separate_all_attention:
            self.attn_mask = build_attn_mask("seperate_all")
        else:
            self.attn_mask = None
        if use_different_backbone:
            # Separate RGB (pretrained=True) and LiDAR (pretrained=False)
            # backbones, each returning only feature stage 4.
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            # NOTE(review): the "r18" LiDAR backbone uses in_chans=3 while
            # "r50"/"r26" use in_chans (12 when direct_concat), yet
            # lidar_patch_embed below always probes with in_chans=3.
            if lidar_backbone_name == "r50":
                self.lidar_backbone = resnet50d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r26":
                self.lidar_backbone = resnet26d(
                    pretrained=False,
                    in_chans=in_chans,
                    features_only=True,
                    out_indices=[4],
                )
            elif lidar_backbone_name == "r18":
                self.lidar_backbone = resnet18d(
                    pretrained=False, in_chans=3, features_only=True, out_indices=[4]
                )
            rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
            if use_mmad_pretrain:
                # Keep only checkpoint keys containing "backbone", strip the
                # "backbone." prefix, and load them into the RGB backbone.
                # NOTE(review): depending on the torch version / weights_only
                # default, torch.load may unpickle arbitrary objects; only pass
                # trusted checkpoint paths.
                params = torch.load(use_mmad_pretrain)["state_dict"]
                updated_params = OrderedDict()
                for key in params:
                    if "backbone" in key:
                        updated_params[key.replace("backbone.", "")] = params[key]
                self.rgb_backbone.load_state_dict(updated_params)
            self.rgb_patch_embed = rgb_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = lidar_embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=3,
                embed_dim=embed_dim,
            )
        else:
            # One pretrained RGB backbone (3 input channels) shared by both
            # patch embeddings; each HybridEmbed still has its own projection.
            # NOTE(review): an unrecognised rgb_backbone_name leaves
            # self.rgb_backbone unset.
            if rgb_backbone_name == "r50":
                self.rgb_backbone = resnet50d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r101":
                self.rgb_backbone = resnet101d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r26":
                self.rgb_backbone = resnet26d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            elif rgb_backbone_name == "r18":
                self.rgb_backbone = resnet18d(
                    pretrained=True, in_chans=3, features_only=True, out_indices=[4]
                )
            embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
            self.rgb_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
            self.lidar_patch_embed = embed_layer(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        # One global and one view embedding per input slot. forward_features
        # uses index 0 front, 1 left, 2 right, 3 front-center, 4 lidar.
        self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
        self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
        # Decoder queries: end2end uses 4. Otherwise 400 traffic queries plus
        # 1 shared junction/light/stop-sign query plus the waypoint queries
        # (4 for "heatmap", 10 for the other heads). query_pos_embed holds the
        # learned positions of the non-grid queries.
        if self.end2end:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 4))
            self.query_embed = nn.Parameter(torch.zeros(4, 1, embed_dim))
        elif self.waypoints_pred_head == "heatmap":
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
            self.query_embed = nn.Parameter(torch.zeros(400 + 5, 1, embed_dim))
        else:
            self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
            self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
        # NOTE(review): "linear" and "linear-sum" both build the predictor
        # with cumsum=True (the default), so they behave identically; an
        # unrecognised waypoints_pred_head leaves waypoints_generator unset.
        if self.end2end:
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim, 4)
        elif self.waypoints_pred_head == "heatmap":
            self.waypoints_generator = MultiPath_Generator(
                embed_dim + 32, embed_dim, 10
            )
        elif self.waypoints_pred_head == "gru":
            self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "gru-command":
            self.waypoints_generator = GRUWaypointsPredictorWithCommand(embed_dim)
        elif self.waypoints_pred_head == "linear":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim)
        elif self.waypoints_pred_head == "linear-sum":
            self.waypoints_generator = LinearWaypointsPredictor(embed_dim, cumsum=True)
        # Two-way classification heads on the shared decoder query.
        self.junction_pred_head = nn.Linear(embed_dim, 2)
        self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
        self.stop_sign_head = nn.Linear(embed_dim, 2)
        # "det": 7 sigmoid outputs per traffic query, computed from the token
        # plus the 32 velocity features appended in forward().
        # NOTE(review): the "seg" head takes embed_dim inputs, but forward()
        # always appends 32 velocity channels, which would not match.
        if self.traffic_pred_head_type == "det":
            self.traffic_pred_head = nn.Sequential(
                *[
                    nn.Linear(embed_dim + 32, 64),
                    nn.ReLU(),
                    nn.Linear(64, 7),
                    nn.Sigmoid(),
                ]
            )
        elif self.traffic_pred_head_type == "seg":
            self.traffic_pred_head = nn.Sequential(
                *[nn.Linear(embed_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Sigmoid()]
            )
        self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
        # act_layer is a class (default nn.GELU) instantiated inside the layers.
        encoder_layer = TransformerEncoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
        decoder_layer = TransformerDecoderLayer(
            embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before
        )
        decoder_norm = nn.LayerNorm(embed_dim)
        self.decoder = TransformerDecoder(
            decoder_layer, dec_depth, decoder_norm, return_intermediate=False
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise the learned embeddings/queries with nn.init.uniform_ (U(0, 1))."""
        nn.init.uniform_(self.global_embed)
        nn.init.uniform_(self.view_embed)
        nn.init.uniform_(self.query_embed)
        nn.init.uniform_(self.query_pos_embed)

    def forward_features(
        self,
        front_image,
        left_image,
        right_image,
        front_center_image,
        lidar,
        measurements,
    ):
        """Embed every enabled input and concatenate the tokens along dim 0.

        For each enabled input, the patch tokens get the view embedding (when
        ``use_view_embed``) plus the sine position encoding, and are flattened
        to (H*W, B, E). Each input also contributes one global token (1, B, E).
        The order is front, [left, right], [front-center], [lidar].
        ``measurements`` is accepted but not used here.
        """
        features = []
        # Front view processing
        front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
        if self.use_view_embed:
            front_image_token = (
                front_image_token
                + self.view_embed[:, :, 0:1, :]
                + self.position_encoding(front_image_token)
            )
        else:
            front_image_token = front_image_token + self.position_encoding(
                front_image_token
            )
        front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
        front_image_token_global = (
            front_image_token_global
            + self.view_embed[:, :, 0, :]
            + self.global_embed[:, :, 0:1]
        )
        front_image_token_global = front_image_token_global.permute(2, 0, 1)
        features.extend([front_image_token, front_image_token_global])
        if self.with_right_left_sensors:
            # Left view processing
            left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
            if self.use_view_embed:
                left_image_token = (
                    left_image_token
                    + self.view_embed[:, :, 1:2, :]
                    + self.position_encoding(left_image_token)
                )
            else:
                left_image_token = left_image_token + self.position_encoding(
                    left_image_token
                )
            left_image_token = left_image_token.flatten(2).permute(2, 0, 1)
            left_image_token_global = (
                left_image_token_global
                + self.view_embed[:, :, 1, :]
                + self.global_embed[:, :, 1:2]
            )
            left_image_token_global = left_image_token_global.permute(2, 0, 1)
            # Right view processing
            right_image_token, right_image_token_global = self.rgb_patch_embed(
                right_image
            )
            if self.use_view_embed:
                right_image_token = (
                    right_image_token
                    + self.view_embed[:, :, 2:3, :]
                    + self.position_encoding(right_image_token)
                )
            else:
                right_image_token = right_image_token + self.position_encoding(
                    right_image_token
                )
            right_image_token = right_image_token.flatten(2).permute(2, 0, 1)
            right_image_token_global = (
                right_image_token_global
                + self.view_embed[:, :, 2, :]
                + self.global_embed[:, :, 2:3]
            )
            right_image_token_global = right_image_token_global.permute(2, 0, 1)
            features.extend(
                [
                    left_image_token,
                    left_image_token_global,
                    right_image_token,
                    right_image_token_global,
                ]
            )
        if self.with_center_sensor:
            # Front center view processing
            (
                front_center_image_token,
                front_center_image_token_global,
            ) = self.rgb_patch_embed(front_center_image)
            if self.use_view_embed:
                front_center_image_token = (
                    front_center_image_token
                    + self.view_embed[:, :, 3:4, :]
                    + self.position_encoding(front_center_image_token)
                )
            else:
                front_center_image_token = (
                    front_center_image_token
                    + self.position_encoding(front_center_image_token)
                )
            front_center_image_token = front_center_image_token.flatten(2).permute(
                2, 0, 1
            )
            front_center_image_token_global = (
                front_center_image_token_global
                + self.view_embed[:, :, 3, :]
                + self.global_embed[:, :, 3:4]
            )
            front_center_image_token_global = front_center_image_token_global.permute(
                2, 0, 1
            )
            features.extend([front_center_image_token, front_center_image_token_global])
        if self.with_lidar:
            # LiDAR processing (through lidar_patch_embed)
            lidar_token, lidar_token_global = self.lidar_patch_embed(lidar)
            if self.use_view_embed:
                lidar_token = (
                    lidar_token
                    + self.view_embed[:, :, 4:5, :]
                    + self.position_encoding(lidar_token)
                )
            else:
                lidar_token = lidar_token + self.position_encoding(lidar_token)
            lidar_token = lidar_token.flatten(2).permute(2, 0, 1)
            lidar_token_global = (
                lidar_token_global
                + self.view_embed[:, :, 4, :]
                + self.global_embed[:, :, 4:5]
            )
            lidar_token_global = lidar_token_global.permute(2, 0, 1)
            features.extend([lidar_token, lidar_token_global])
        features = torch.cat(features, 0)
        return features

    def forward(self, x):
        """Run the full model.

        Args:
            x: dict with keys "rgb", "rgb_left", "rgb_right", "rgb_center",
                "measurements", "target_point" and "lidar". Within this file,
                ``measurements[:, :6]`` is consumed as command weights by the
                waypoint heads that take measurements, and
                ``measurements[:, 6:7]`` as velocity.

        Returns:
            ``waypoints`` alone when ``end2end``; otherwise the tuple
            ``(traffic, waypoints, is_junction, traffic_light_state,
            stop_sign, traffic_feature)``.

        NOTE(review): if waypoints_pred_head matches none of the branches
        below, ``waypoints`` is never assigned and the return raises.
        """
        front_image = x["rgb"]
        left_image = x["rgb_left"]
        right_image = x["rgb_right"]
        front_center_image = x["rgb_center"]
        measurements = x["measurements"]
        target_point = x["target_point"]
        lidar = x["lidar"]
        if self.direct_concat:
            # Resize the other three views to a square of the front image's
            # width, then stack all four views channel-wise.
            img_size = front_image.shape[-1]
            left_image = torch.nn.functional.interpolate(
                left_image, size=(img_size, img_size)
            )
            right_image = torch.nn.functional.interpolate(
                right_image, size=(img_size, img_size)
            )
            front_center_image = torch.nn.functional.interpolate(
                front_center_image, size=(img_size, img_size)
            )
            front_image = torch.cat(
                [front_image, left_image, right_image, front_center_image], dim=1
            )
        features = self.forward_features(
            front_image,
            left_image,
            right_image,
            front_center_image,
            lidar,
            measurements,
        )
        bs = front_image.shape[0]
        # Query positions: learned only (end2end), or a 20x20 sine grid for
        # the 400 traffic queries followed by the learned positions.
        if self.end2end:
            tgt = self.query_pos_embed.repeat(bs, 1, 1)
        else:
            tgt = self.position_encoding(
                torch.ones((bs, 1, 20, 20), device=x["rgb"].device)
            )
            tgt = tgt.flatten(2)
            tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2)
        tgt = tgt.permute(2, 0, 1)  # -> (num_queries, bs, embed_dim)
        memory = self.encoder(features, mask=self.attn_mask)
        # The decoder returns a leading layer axis; [0] takes the final output.
        hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0]
        hs = hs.permute(1, 0, 2) # Batchsize , N, C
        if self.end2end:
            waypoints = self.waypoints_generator(hs, target_point)
            return waypoints
        # Query 400 is shared by the junction, traffic-light and stop-sign heads.
        if self.waypoints_pred_head != "heatmap":
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:411]
        else:
            traffic_feature = hs[:, :400]
            is_junction_feature = hs[:, 400]
            traffic_light_state_feature = hs[:, 400]
            stop_sign_feature = hs[:, 400]
            waypoints_feature = hs[:, 401:405]
        if self.waypoints_pred_head == "heatmap":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "gru":
            waypoints = self.waypoints_generator(waypoints_feature, target_point)
        elif self.waypoints_pred_head == "gru-command":
            waypoints = self.waypoints_generator(waypoints_feature, target_point, measurements)
        elif self.waypoints_pred_head == "linear":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        elif self.waypoints_pred_head == "linear-sum":
            waypoints = self.waypoints_generator(waypoints_feature, measurements)
        is_junction = self.junction_pred_head(is_junction_feature)
        traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
        stop_sign = self.stop_sign_head(stop_sign_feature)
        # Broadcast velocity to 32 extra features on each of the 400 traffic tokens.
        velocity = measurements[:, 6:7].unsqueeze(-1)
        velocity = velocity.repeat(1, 400, 32)
        traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
        traffic = self.traffic_pred_head(traffic_feature_with_vel)
        return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature