# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.

import warnings

import nncore
import torch
import torch.nn as nn
import torch.nn.functional as F
from nncore.nn import ModuleList, PositionalEncoding, Sequential, TransformerEncoderLayer, xavier_init_
from nncore.ops import temporal_iou
from transformers import AutoConfig, AutoModel, Qwen2VLConfig, Qwen2VLForConditionalGeneration, Qwen2VLModel
from transformers.activations import ACT2CLS, ACT2FN
from transformers.models.auto.modeling_auto import MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VisionTransformerPretrainedModel

from .blocks import ConvHead, ConvPyramid, LearnableEmbedding, Scale
from .generator import PointGenerator
from .loss import BundleLoss
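

# Forward pre-hook that stashes a module's first positional input on the module itself,
# so the hidden states entering that module can be read back later (see AgentQwen2VLModel).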
def cache_state_hook(module, args):
module.state = args[0]
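

# Config subclass with a distinct model_type so the patched classes below can be
# registered with AutoConfig / AutoModel.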
class AgentQwen2VLConfig(Qwen2VLConfig):
model_type = 'agent_qwen2_vl'
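

# Vision tower of Qwen2-VL with a forward pass that additionally supports gradient checkpointing.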
class AgentQwen2VisionTransformerPretrainedModel(Qwen2VisionTransformerPretrainedModel):
def __init__(self, config):
super().__init__(config)
self.gradient_checkpointing = False
# add support for gradient checkpointing
# https://github.com/huggingface/transformers/pull/34724
def forward(self, hidden_states, grid_thw):
hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for blk in self.blocks:
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens,
rotary_pos_emb)
else:
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
return self.merger(hidden_states)
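

# Text backbone that caches the hidden states entering its final norm layer
# (via cache_state_hook), so the grounding head can reuse them later.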
class AgentQwen2VLModel(Qwen2VLModel):
config_class = AgentQwen2VLConfig
def __init__(self, config):
super().__init__(config)
self.norm.register_forward_pre_hook(cache_state_hook)
def forward(self, input_ids=None, inputs_embeds=None, **kwargs):
        # ensure gradient tracking on the input embeddings (needed when embed_tokens is frozen)
assert input_ids is None and inputs_embeds is not None
if self.training and not inputs_embeds.requires_grad:
inputs_embeds.requires_grad = True
return super().forward(input_ids=input_ids, inputs_embeds=inputs_embeds, **kwargs)
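

# Qwen2-VL for conditional generation built from the patched vision tower and text model,
# plus an additional temporal grounding head for the 'all_in_one' and 'grounder' roles.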
class AgentQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration):
config_class = AgentQwen2VLConfig
def __init__(self, config):
super().__init__(config)
self.visual = AgentQwen2VisionTransformerPretrainedModel._from_config(config.vision_config)
self.model = AgentQwen2VLModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.rope_deltas = None
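
        # Temporal grounding head, built only for the 'all_in_one' and 'grounder' roles
        # (`role` is a custom field expected on the config): projections into a 256-d
        # space, a small transformer that fuses frame tokens with the grounding (reg)
        # token, a conv pyramid with classification and boundary-regression heads, a
        # point generator, and the bundled loss.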
if self.config.role in ('all_in_one', 'grounder'):
hidden_size, hidden_act = self.config.hidden_size, self.config.hidden_act
self.dims = 256
self.vis_proj = Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, self.dims))
self.reg_proj = Sequential(nn.LayerNorm(hidden_size), nn.Linear(hidden_size, self.dims))
self.vis_norm = nn.LayerNorm(self.dims)
self.vis_fuse = ModuleList(
TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]),
TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]),
TransformerEncoderLayer(self.dims, act_cfg=ACT2FN[hidden_act]))
self.vis_pos = PositionalEncoding(self.dims, normalize=True, learnable=False)
self.vis_emb = LearnableEmbedding(self.dims)
self.reg_emb = LearnableEmbedding(self.dims)
self.strides = (1, 2, 4, 8)
self.vis_pad_length = self.strides[-1]
self.pyramid = ConvPyramid(self.dims, self.strides, act_cls=ACT2CLS[hidden_act])
self.class_head = ConvHead(self.dims, 1, act_cls=ACT2CLS[hidden_act])
self.coord_head = ConvHead(self.dims, 2, act_cls=ACT2CLS[hidden_act])
self.generator = PointGenerator(self.strides, 1024)
self.coef = Scale(self.strides)
self.bundle_loss = BundleLoss(
sample_radius=1.5,
loss_cls=dict(type='FocalLoss', reduction='none', loss_weight=5.0),
loss_reg=dict(type='L1Loss', reduction='none', loss_weight=1.0),
loss_sal=dict(type='SampledNCELoss', direction='row', loss_weight=0.05))
self.post_init()
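
    # Re-initialize the conv layers of the grounding head with Xavier-uniform weights.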
def reset_conv_parameters(self):
for s in ('pyramid', 'class_head', 'coord_head'):
b = getattr(self, s, None)
if b is None:
continue
for n, m in b.named_modules():
if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
print(f'Reset parameters of {b.__class__.__name__} {n} ({m.__class__.__name__})')
xavier_init_(m, distribution='uniform')
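
    # Besides the standard Qwen2-VL arguments, `timestamps`, `saliency` and `pos_clip`
    # carry the grounding targets and are only consumed when computing the grounding loss.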
def forward(self,
input_ids=None,
attention_mask=None,
position_ids=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
pixel_values=None,
pixel_values_videos=None,
image_grid_thw=None,
video_grid_thw=None,
rope_deltas=None,
timestamps=None,
saliency=None,
pos_clip=None):
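        # Three modes: 'training' (loss computation), 'caching' (the prefill pass with an
        # empty KV cache, where hidden states over the video tokens are cached), and
        # 'generating' (subsequent decoding steps that may trigger the grounding head).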
mode = 'training' if self.training else 'caching' if (
past_key_values is None or len(past_key_values) == 0) else 'generating'
# https://github.com/huggingface/transformers/pull/33487
if position_ids is None and input_ids is not None:
position_ids, _ = self.get_rope_index(input_ids, image_grid_thw, video_grid_thw, attention_mask)
if mode in ('training', 'caching'):
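            # Record the [start, end) index range of each vision token span per sample,
            # so the video tokens can be located in the cached hidden states later.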
vision_s_inds = torch.nonzero(input_ids == self.config.vision_start_token_id).tolist()
vision_e_inds = torch.nonzero(input_ids == self.config.vision_end_token_id).tolist()
assert len(vision_s_inds) == len(vision_e_inds)
self.cache_vision_inds = [[] for _ in range(input_ids.size(0))]
for i in range(len(vision_s_inds)):
assert vision_s_inds[i][0] == vision_e_inds[i][0]
self.cache_vision_inds[vision_s_inds[i][0]].append([vision_s_inds[i][1] + 1, vision_e_inds[i][1]])
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
labels=labels,
use_cache=not self.training,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
pixel_values=pixel_values,
pixel_values_videos=pixel_values_videos,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
rope_deltas=rope_deltas)
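        # On the caching (prefill) pass, keep the hidden states that entered the final
        # norm layer and reset the prediction buffers filled during generation.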
if mode == 'caching':
self.cache_norm_state = self.model.norm.state
self.reg = []
self.sal = []
if mode == 'training' and timestamps is not None:
loss_regs, avg_factors = [], []
shift_labels = labels[..., 1:].contiguous()
for batch_idx, (vision_inds, ts) in enumerate(zip(self.cache_vision_inds, timestamps)):
# only consider the first video
s, e = vision_inds[0]
# spatial merge size set to 2
window = int(video_grid_thw[0][1] * video_grid_thw[0][2] / 4)
assert video_grid_thw[0][0] * window == e - s
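                # The hidden state at position i predicts token i + 1, so matching the
                # shifted labels against reg_token_id picks the hidden states that emit
                # the grounding token.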
inds = torch.where(shift_labels[batch_idx] == self.config.reg_token_id)[0]
reg_tokens = self.reg_proj(self.model.norm.state[batch_idx, inds, None])
# reg_tokens: num_reg_tokens * 1 * channel
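                # Average-pool the video tokens over each frame's window so that every
                # frame is represented by a single embedding.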
vis_tokens = self.model.norm.state[batch_idx, None, s:e]
vis_tokens = vis_tokens.transpose(-1, -2)
vis_tokens = F.avg_pool1d(vis_tokens.float(), window, stride=window).to(vis_tokens.dtype)
vis_tokens = vis_tokens.transpose(-1, -2)
vis_tokens = self.vis_proj(vis_tokens).repeat(reg_tokens.size(0), 1, 1)
# vis_tokens: num_reg_tokens * num_frames * channel
vis_tokens = self.vis_emb(vis_tokens)
reg_tokens = self.reg_emb(reg_tokens)
pe = self.vis_pos(vis_tokens).to(vis_tokens.dtype)
joint_tokens = torch.cat((vis_tokens + pe, reg_tokens), dim=1)
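                # Pass the joint sequence through the fusion layers, keeping each layer's
                # output; the per-layer outputs are concatenated along the batch dimension,
                # presumably so that every layer's output is supervised by the grounding loss.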
collected = [joint_tokens]
for blk in self.vis_fuse:
collected.append(blk(collected[-1]))
collected = collected[1:]
joint_tokens = torch.cat(collected)
joint_tokens = self.vis_norm(joint_tokens)
video_emb = joint_tokens[:, :-1]
# video_emb: num_reg_tokens * num_frames * channel
query_emb = joint_tokens[:, -1:]
# query_emb: num_reg_tokens * 1 * channel
b, t, c = video_emb.size()
video_msk = video_emb.new_ones(b, t)
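                # Pad short sequences up to the largest stride so the conv pyramid can
                # always downsample them.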
if t < self.vis_pad_length:
emb_pad = video_emb.new_zeros(b, self.vis_pad_length - t, c)
msk_pad = video_msk.new_zeros(b, self.vis_pad_length - t)
pymid_emb = torch.cat((video_emb, emb_pad), dim=1)
pymid_msk = torch.cat((video_msk, msk_pad), dim=1)
else:
pymid_emb, pymid_msk = video_emb, video_msk
pymid, pymid_msk = self.pyramid(pymid_emb, pymid_msk, return_mask=True)
if not len(pymid) == len(pymid_msk) == len(self.strides):
warnings.warn(f'pyramid size mismatch: {len(pymid)} {len(pymid_msk)} {len(self.strides)}')
point = self.generator(pymid)
out_class = [self.class_head(e) for e in pymid]
out_class = torch.cat(out_class, dim=1)
out_coord = [self.coef(self.coord_head(e).exp(), i) for i, e in enumerate(pymid)]
out_coord = torch.cat(out_coord, dim=1)
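                # Bundle the candidate points, embeddings, masks, predictions and targets
                # for the grounding loss.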
data = dict(
point=point,
video_emb=video_emb,
query_emb=query_emb,
video_msk=video_msk,
pymid_msk=pymid_msk,
out_class=out_class,
out_coord=out_coord,
boundary=point.new_tensor(ts),
saliency=saliency[batch_idx].unsqueeze(0),
pos_clip=pos_clip[batch_idx].unsqueeze(0))
losses = self.bundle_loss(data, dict())
# print({k: v.item() for k, v in losses.items()})
loss_regs.append(sum(v for v in losses.values()))
avg_factors.append(len(ts))
assert len(loss_regs) in (1, 2) and len(loss_regs) == len(avg_factors)
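            # With two grounding samples in the batch, keep only the smaller of the two
            # regression losses (and its normalization factor).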
if len(loss_regs) == 2 and loss_regs[0] > loss_regs[1]:
loss_reg, avg_factor = loss_regs[1], avg_factors[1]
else:
loss_reg, avg_factor = loss_regs[0], avg_factors[0]
if avg_factor > 0:
outputs.loss = outputs.loss + loss_reg / avg_factor
elif mode == 'generating':
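            # Run the grounding head only when the most likely next token is the
            # grounding (reg) token.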
logits = outputs.logits[0, -1]
if logits.argmax() == self.config.reg_token_id:
assert self.model.norm.state.size() == (1, 1, self.config.hidden_size)
# only consider the first video
s, e = self.cache_vision_inds[0][0]
# spatial merge size set to 2
window = int(video_grid_thw[0][1] * video_grid_thw[0][2] / 4)
assert video_grid_thw[0][0] * window == e - s
reg_tokens = self.reg_proj(self.model.norm.state)
# reg_tokens: num_reg_tokens * 1 * channel
vis_tokens = self.cache_norm_state[:, s:e]
vis_tokens = vis_tokens.transpose(-1, -2)
vis_tokens = F.avg_pool1d(vis_tokens.float(), window, stride=window).to(vis_tokens.dtype)
vis_tokens = vis_tokens.transpose(-1, -2)
vis_tokens = self.vis_proj(vis_tokens).repeat(reg_tokens.size(0), 1, 1)
# vis_tokens: num_reg_tokens * num_frames * channel
vis_tokens = self.vis_emb(vis_tokens)
reg_tokens = self.reg_emb(reg_tokens)
pe = self.vis_pos(vis_tokens).to(vis_tokens.dtype)
joint_tokens = torch.cat((vis_tokens + pe, reg_tokens), dim=1)
for blk in self.vis_fuse:
joint_tokens = blk(joint_tokens)
joint_tokens = self.vis_norm(joint_tokens)
video_emb = joint_tokens[:, :-1]
# video_emb: num_reg_tokens * num_frames * channel
query_emb = joint_tokens[:, -1:]
# query_emb: num_reg_tokens * 1 * channel
b, t, _ = video_emb.size()
video_msk = video_emb.new_ones(b, t)
pymid = self.pyramid(video_emb, video_msk)
point = self.generator(pymid)
out_class = [self.class_head(e).sigmoid() for e in pymid]
out_class = torch.cat(out_class, dim=1)
out_coord = [self.coef(self.coord_head(e).exp(), i) for i, e in enumerate(pymid)]
out_coord = torch.cat(out_coord, dim=1)
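                # Decode segments: out_class[0] / out_coord[0] drop the batch dimension
                # (a single reg token here); the left offset is negated, both offsets are
                # scaled by each point's regression scale (column 3) and shifted by its
                # location (column 0), then normalized by the number of frames.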
sal = out_class[0]
bnd = out_coord[0]
bnd[:, 0] *= -1
bnd *= point[:, 3, None].repeat(1, 2)
bnd += point[:, 0, None].repeat(1, 2)
bnd /= t
bnd = torch.cat((bnd, sal), dim=-1)
_, inds = bnd[:, -1].sort(descending=True)
bnd = bnd[inds]
# hard coding nms config here
nms_cfg = dict(type='normal', thres=0.75)
assert nms_cfg['type'] in ('normal', 'linear', 'gaussian')
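                # Greedy temporal NMS: repeatedly promote the highest-scoring remaining
                # segment and suppress (or down-weight) the others according to their IoU
                # with it. Note that the 'gaussian' branch would additionally require a
                # 'sigma' entry in nms_cfg.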
for i in range(bnd.size(0)):
max_idx = bnd[i:, -1].argmax(dim=0)
bnd = nncore.swap_element(bnd, i, max_idx + i)
iou = temporal_iou(bnd[i, None, :-1], bnd[i + 1:, :-1])[0]
if nms_cfg['type'] == 'normal':
bnd[i + 1:, -1][iou >= nms_cfg['thres']] = 0
elif nms_cfg['type'] == 'linear':
bnd[i + 1:, -1] *= 1 - iou
else:
bnd[i + 1:, -1] *= (-iou.pow(2) / nms_cfg['sigma']).exp()
# save top-100 predictions
self.reg.append(bnd[:100])
# save all saliency scores
self.sal.append(sal)
return outputs


# register the patched model as a vision-to-sequence model so the Auto classes can resolve it
MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES[AgentQwen2VLConfig.model_type] = 'AgentQwen2VLForConditionalGeneration'
AutoConfig.register(AgentQwen2VLConfig.model_type, AgentQwen2VLConfig)
AutoModel.register(AgentQwen2VLConfig, AgentQwen2VLForConditionalGeneration)
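
# Note: after these registrations, a checkpoint whose config declares
# model_type 'agent_qwen2_vl' can be loaded through the Auto classes
# (e.g. AutoModel.from_pretrained on such a checkpoint resolves to
# AgentQwen2VLForConditionalGeneration); the extra fields used above
# (role, reg_token_id) are expected to be present in that config.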