# my_interfuser_model/modeling_interfuser.py
# -*- coding: utf-8 -*-
# This file contains all custom class definitions required to run the Interfuser model.
# Setup (run in a shell / notebook cell before importing this module):
#   git clone https://github.com/opendilab/InterFuser.git
#   pip install timm
import sys
sys.path.append('/content/InterFuser')  # make the cloned InterFuser repo importable
import math
import copy
import logging
from collections import OrderedDict
from functools import partial
from typing import Optional, List
import numpy as np
import torch
from torch import nn, Tensor
import torch.nn.functional as F
# --- These imports from the cloned InterFuser repo must be present ---
from InterFuser.interfuser.timm.models.layers import StdConv2dSame, StdConv2d, to_2tuple
from InterFuser.interfuser.timm.models.registry import register_model
from InterFuser.interfuser.timm.models.resnet import resnet26d, resnet50d, resnet18d, resnet26, resnet50, resnet101d
# --------------------------------------------------------
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
import os
# ==============================================================================
# SECTION 1: ALL DEPENDENCY CLASSES FROM THE ORIGINAL CODE
# ==============================================================================
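# HybridEmbed wraps a CNN backbone (here a timm ResNet) and projects its final
# feature map to the transformer embedding dimension with a 1x1 convolution; it
# returns both the spatial token map and a globally average-pooled token.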
class HybridEmbed(nn.Module):
def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
super().__init__()
assert isinstance(backbone, nn.Module)
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.backbone = backbone
if feature_size is None:
with torch.no_grad():
training = backbone.training
if training:
backbone.eval()
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
if isinstance(o, (list, tuple)):
o = o[-1]
feature_size = o.shape[-2:]
feature_dim = o.shape[1]
backbone.train(training)
else:
feature_size = to_2tuple(feature_size)
if hasattr(self.backbone, "feature_info"):
feature_dim = self.backbone.feature_info.channels()[-1]
else:
feature_dim = self.backbone.num_features
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=1, stride=1)
def forward(self, x):
x = self.backbone(x)
if isinstance(x, (list, tuple)):
x = x[-1]
x = self.proj(x)
global_x = torch.mean(x, [2, 3], keepdim=False)[:, :, None]
return x, global_x
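# DETR-style fixed sine/cosine 2-D positional encoding computed over the feature-map grid.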
class PositionEmbeddingSine(nn.Module):
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
bs, _, h, w = x.shape
not_mask = torch.ones((bs, h, w), device=x.device)
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class TransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU(), normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, src, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(self, src, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None):
output = src
for layer in self.layers:
output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=nn.ReLU(), normalize_before=False):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = activation
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), key=self.with_pos_embed(memory, pos), value=memory, attn_mask=memory_mask, key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(self, tgt, memory, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, pos: Optional[Tensor] = None, query_pos: Optional[Tensor] = None):
output = tgt
for layer in self.layers:
output = layer(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, pos=pos, query_pos=query_pos)
if self.norm is not None:
output = self.norm(output)
return output.unsqueeze(0)
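# Decodes the waypoint query features with a GRU whose initial hidden state encodes the
# target point; per-step (x, y) offsets are accumulated with cumsum to yield the waypoints.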
class GRUWaypointsPredictor(nn.Module):
def __init__(self, input_dim, waypoints=10):
super().__init__()
self.gru = torch.nn.GRU(input_size=input_dim, hidden_size=64, batch_first=True)
self.encoder = nn.Linear(2, 64)
self.decoder = nn.Linear(64, 2)
self.waypoints = waypoints
def forward(self, x, target_point):
bs = x.shape[0]
z = self.encoder(target_point).unsqueeze(0)
output, _ = self.gru(x, z)
output = output.reshape(bs * self.waypoints, -1)
output = self.decoder(output).reshape(bs, self.waypoints, 2)
output = torch.cumsum(output, 1)
return output
# ... (Add other dependency classes like SpatialSoftmax, MultiPath_Generator, etc. if needed by other configs)
# --- The ORIGINAL Interfuser Model Class ---
class Interfuser(nn.Module):
def __init__(self, img_size=224,
multi_view_img_size=112,
patch_size=8, in_chans=3,
embed_dim=768,
enc_depth=6,
dec_depth=6,
dim_feedforward=2048,
normalize_before=False,
rgb_backbone_name="r26",
lidar_backbone_name="r26",
num_heads=8, norm_layer=None,
dropout=0.1, end2end=False,
direct_concat=True,
separate_view_attention=False,
separate_all_attention=False,
act_layer=None,
weight_init="",
freeze_num=-1,
with_lidar=False,
with_right_left_sensors=True,
with_center_sensor=False,
traffic_pred_head_type="det",
waypoints_pred_head="heatmap",
reverse_pos=True,
use_different_backbone=False,
use_view_embed=True,
use_mmad_pretrain=None):
super().__init__()
self.num_features = self.embed_dim = embed_dim
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU()  # the encoder/decoder layers expect an activation instance
self.waypoints_pred_head = waypoints_pred_head
self.with_lidar = with_lidar
self.with_right_left_sensors = with_right_left_sensors
self.attn_mask = None # Simplified
if use_different_backbone:
if rgb_backbone_name == "r50": self.rgb_backbone = resnet50d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
if rgb_backbone_name == "r26": self.rgb_backbone = resnet26d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
if lidar_backbone_name == "r18": self.lidar_backbone = resnet18d(pretrained=False, in_chans=3, features_only=True, out_indices=[4])
rgb_embed_layer = partial(HybridEmbed, backbone=self.rgb_backbone)
lidar_embed_layer = partial(HybridEmbed, backbone=self.lidar_backbone)
self.rgb_patch_embed = rgb_embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
self.lidar_patch_embed = lidar_embed_layer(img_size=img_size, patch_size=patch_size, in_chans=3, embed_dim=embed_dim)
else: raise NotImplementedError("Only use_different_backbone=True supported in this wrapper")
self.global_embed = nn.Parameter(torch.zeros(1, embed_dim, 5))
self.view_embed = nn.Parameter(torch.zeros(1, embed_dim, 5, 1))
self.query_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, 11))
self.query_embed = nn.Parameter(torch.zeros(400 + 11, 1, embed_dim))
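        # 411 decoder queries in total: 400 queries for the traffic feature grid (20x20)
        # plus 11 task queries -- index 400 feeds the junction / traffic-light / stop-sign
        # heads and indices 401-410 feed the waypoint head (see forward()).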
if self.waypoints_pred_head == "gru": self.waypoints_generator = GRUWaypointsPredictor(embed_dim)
else: raise NotImplementedError("Only GRU waypoints head supported in this wrapper")
self.junction_pred_head = nn.Linear(embed_dim, 2)
self.traffic_light_pred_head = nn.Linear(embed_dim, 2)
self.stop_sign_head = nn.Linear(embed_dim, 2)
self.traffic_pred_head = nn.Sequential(*[nn.Linear(embed_dim + 32, 64), nn.ReLU(), nn.Linear(64, 7), nn.Sigmoid()])
self.position_encoding = PositionEmbeddingSine(embed_dim // 2, normalize=True)
encoder_layer = TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before)
self.encoder = TransformerEncoder(encoder_layer, enc_depth, None)
decoder_layer = TransformerDecoderLayer(embed_dim, num_heads, dim_feedforward, dropout, act_layer, normalize_before)
decoder_norm = nn.LayerNorm(embed_dim)
self.decoder = TransformerDecoder(decoder_layer, dec_depth, decoder_norm, return_intermediate=False)
def forward_features(self, front_image, left_image, right_image, front_center_image, lidar, measurements):
features = []
front_image_token, front_image_token_global = self.rgb_patch_embed(front_image)
front_image_token = (front_image_token + self.position_encoding(front_image_token))
front_image_token = front_image_token.flatten(2).permute(2, 0, 1)
front_image_token_global = (front_image_token_global + self.global_embed[:, :, 0:1])
front_image_token_global = front_image_token_global.permute(2, 0, 1)
features.extend([front_image_token, front_image_token_global])
left_image_token, left_image_token_global = self.rgb_patch_embed(left_image)
left_image_token = (left_image_token + self.position_encoding(left_image_token)).flatten(2).permute(2, 0, 1)
left_image_token_global = (left_image_token_global + self.global_embed[:, :, 1:2]).permute(2, 0, 1)
right_image_token, right_image_token_global = self.rgb_patch_embed(right_image)
right_image_token = (right_image_token + self.position_encoding(right_image_token)).flatten(2).permute(2, 0, 1)
right_image_token_global = (right_image_token_global + self.global_embed[:, :, 2:3]).permute(2, 0, 1)
features.extend([left_image_token, left_image_token_global, right_image_token, right_image_token_global])
return torch.cat(features, 0)
def forward(self, x):
front_image, left_image, right_image = x["rgb"], x["rgb_left"], x["rgb_right"]
measurements, target_point = x["measurements"], x["target_point"]
features = self.forward_features(front_image, left_image, right_image, x["rgb_center"], x["lidar"], measurements)
bs = front_image.shape[0]
tgt = self.position_encoding(torch.ones((bs, 1, 20, 20), device=x["rgb"].device)).flatten(2)
tgt = torch.cat([tgt, self.query_pos_embed.repeat(bs, 1, 1)], 2).permute(2, 0, 1)
memory = self.encoder(features, mask=self.attn_mask)
hs = self.decoder(self.query_embed.repeat(1, bs, 1), memory, query_pos=tgt)[0].permute(1, 0, 2)
traffic_feature = hs[:, :400]
waypoints_feature = hs[:, 401:411]
is_junction_feature = hs[:, 400]
traffic_light_state_feature, stop_sign_feature = hs[:, 400], hs[:, 400]
waypoints = self.waypoints_generator(waypoints_feature, target_point)
is_junction = self.junction_pred_head(is_junction_feature)
traffic_light_state = self.traffic_light_pred_head(traffic_light_state_feature)
stop_sign = self.stop_sign_head(stop_sign_feature)
velocity = measurements[:, 6:7].unsqueeze(-1).repeat(1, 400, 32)
traffic_feature_with_vel = torch.cat([traffic_feature, velocity], dim=2)
traffic = self.traffic_pred_head(traffic_feature_with_vel)
return traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature
# ==============================================================================
# SECTION 2: HUGGING FACE WRAPPER CLASSES
# ==============================================================================
# ==============================================================================
# Add this code at the end of the cell that defines the original model
# ==============================================================================
print("
--- Defining Hugging Face compatible wrapper classes ---")
# --- 1. HF-Compatible Configuration Class ---
class InterfuserConfig(PretrainedConfig):
model_type = "interfuser"
def __init__(
self,
        in_chans=3,
        embed_dim=256,
enc_depth=6,
dec_depth=6,
num_heads=8,
dim_feedforward=2048,
rgb_backbone_name="r50",
lidar_backbone_name="r18",
waypoints_pred_head="gru",
use_different_backbone=True,
**kwargs
):
super().__init__(**kwargs)
        self.in_chans = in_chans
        self.embed_dim = embed_dim
self.enc_depth = enc_depth
self.dec_depth = dec_depth
self.num_heads = num_heads
self.dim_feedforward = dim_feedforward
self.rgb_backbone_name = rgb_backbone_name
self.lidar_backbone_name = lidar_backbone_name
self.waypoints_pred_head = waypoints_pred_head
self.use_different_backbone = use_different_backbone
# Add the architectures key for auto-mapping
self.architectures = ["InterfuserForHuggingFace"]
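# Illustrative usage (not executed here): like any PretrainedConfig subclass, the
# config can be serialized with the standard Hugging Face API, e.g.
#   cfg = InterfuserConfig(rgb_backbone_name="r50", waypoints_pred_head="gru")
#   cfg.save_pretrained("./my_interfuser_model")   # writes config.json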
# --- 2. HF-Compatible Model Class ---
# This is the new version of the model that inherits from PreTrainedModel.
class InterfuserForHuggingFace(PreTrainedModel):
config_class = InterfuserConfig # Link to the config class
def __init__(self, config: InterfuserConfig):
super().__init__(config)
self.config = config
# We instantiate the original Interfuser model inside our wrapper
# The parameters are taken from our config object.
        # This requires the original 'Interfuser' class defined above in this file.
self.interfuser_model = Interfuser(
            in_chans=self.config.in_chans,  # the channel count is forwarded from the config here
embed_dim=self.config.embed_dim,
enc_depth=self.config.enc_depth,
dec_depth=self.config.dec_depth,
num_heads=self.config.num_heads,
dim_feedforward=self.config.dim_feedforward,
rgb_backbone_name=self.config.rgb_backbone_name,
lidar_backbone_name=self.config.lidar_backbone_name,
waypoints_pred_head=self.config.waypoints_pred_head,
use_different_backbone=self.config.use_different_backbone
)
def forward(self, rgb, rgb_left, rgb_right, rgb_center, lidar, measurements, target_point, **kwargs):
# The original model expects a dictionary, so we create one.
inputs_dict = {
'rgb': rgb,
'rgb_left': rgb_left,
'rgb_right': rgb_right,
'rgb_center': rgb_center,
'lidar': lidar,
'measurements': measurements,
'target_point': target_point
}
# Call the forward method of the original model
# The output is already a tuple, which is what HF expects.
return self.interfuser_model.forward(inputs_dict)
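
# ==============================================================================
# SECTION 3: OPTIONAL SMOKE TEST (ILLUSTRATIVE SKETCH ONLY)
# ==============================================================================
# Minimal example of how the wrapper can be exercised with random tensors. The
# input shapes below (224x224 RGB crops, a 7-element measurement vector, a 2-D
# target point) are assumptions chosen only so that the forward pass runs; they
# are not necessarily the shapes used by the original InterFuser training
# pipeline. Running this requires the cloned InterFuser repo (for the bundled
# timm backbones) to be importable, as set up at the top of this file.
if __name__ == "__main__":
    config = InterfuserConfig()  # defaults: r50 RGB backbone, r18 lidar backbone, GRU waypoint head
    model = InterfuserForHuggingFace(config).eval()

    bs = 1
    dummy_inputs = {
        "rgb": torch.randn(bs, 3, 224, 224),
        "rgb_left": torch.randn(bs, 3, 224, 224),
        "rgb_right": torch.randn(bs, 3, 224, 224),
        "rgb_center": torch.randn(bs, 3, 224, 224),  # accepted but unused by this simplified wrapper
        "lidar": torch.randn(bs, 3, 224, 224),       # accepted but unused by this simplified wrapper
        "measurements": torch.randn(bs, 7),          # index 6 is read as the ego velocity
        "target_point": torch.randn(bs, 2),
    }
    with torch.no_grad():
        traffic, waypoints, is_junction, traffic_light_state, stop_sign, traffic_feature = model(**dummy_inputs)
    print("traffic map:", traffic.shape)      # (bs, 400, 7)
    print("waypoints:", waypoints.shape)      # (bs, 10, 2)
    print("is_junction:", is_junction.shape)  # (bs, 2)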