# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.
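
# Run the VideoMind grounder over a QVHighlights-style evaluation split and dump the predicted
# moments in the `pred_relevant_windows` submission format.
#
# Example invocation (dataset name and paths are illustrative, not prescribed by this file):
#   python videomind/eval/infer_qvhighlights.py \
#       --dataset qvhighlights --split valid \
#       --model_gnd_path model_zoo/VideoMind-2B \
#       --pred_path outputs/qvhighlights_valid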
import argparse
import copy

import nncore
import torch

from videomind.constants import GROUNDER_PROMPT
from videomind.dataset.hybrid import DATASETS
from videomind.dataset.utils import process_vision_info
from videomind.model.builder import build_model
from videomind.utils.io import get_duration


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset')
    parser.add_argument('--pred_path')
    parser.add_argument('--model_gnd_path')
    parser.add_argument('--split', default='test', choices=['train', 'valid', 'test'])
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--device', default='auto')
    parser.add_argument('--chunk', type=int, default=1)
    parser.add_argument('--index', type=int, default=0)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if args.chunk > 1:
        pred_path = nncore.join(args.pred_path, f'output_{args.index}.jsonl')
    else:
        pred_path = nncore.join(args.pred_path, 'output.jsonl')

    print(f'Dataset: {args.dataset}({args.split}) Chunk: {args.chunk} Index: {args.index} Output Path: {pred_path}')
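
    # build the grounder model and its processor; '--device auto' defers device placement to the builder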
    model, processor = build_model(args.model_gnd_path, device=args.device)
    device = next(model.parameters()).device
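
    # load annotations for the requested split and keep this worker's round-robin shard (--chunk/--index)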
    annos = DATASETS.get(args.dataset).load_annos(split=args.split)
    annos = [annos[i::args.chunk] for i in range(args.chunk)][args.index]

    dumps = []
    for i in nncore.ProgressBar(range(len(annos))):
        anno = copy.deepcopy(annos[i])
        dump = dict()

        video_path, query, duration, span = anno['video_path'], anno['query'], anno.get('duration'), anno.get('span')
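
        # probe the video file for its duration when the annotation does not provide one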
        if duration is None:
            duration = get_duration(video_path, num_threads=args.num_threads)

        print()
        print(video_path)
        print(duration)
        print(query)
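
        # single-turn conversation: the video (decoded at 1 fps, at most 150 frames, with the per-frame
        # pixel budget expressed in 28x28 patch units) followed by the grounding prompt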
        messages = [{
            'role':
            'user',
            'content': [{
                'type': 'video',
                'video': video_path,
                'num_threads': args.num_threads,
                'min_pixels': 36 * 28 * 28,
                'max_pixels': 64 * 28 * 28,
                'max_frames': 150,
                'fps': 1.0
            }, {
                'type': 'text',
                'text': GROUNDER_PROMPT.format(query)
            }]
        }]
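
        # render the chat template, then pack the sampled frames and text into model inputs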
        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        print(text)

        images, videos = process_vision_info(messages)
        data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
        data = data.to(device)
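
        # greedy decoding; the sampling-related arguments are explicitly unset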
        output_ids = model.generate(
            **data,
            do_sample=False,
            temperature=None,
            top_p=None,
            top_k=None,
            repetition_penalty=None,
            max_new_tokens=256)
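
        # strip the prompt tokens and a trailing EOS before decoding the response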
        assert data.input_ids.size(0) == output_ids.size(0) == 1
        output_ids = output_ids[0, data.input_ids.size(1):]
        if output_ids[-1] == processor.tokenizer.eos_token_id:
            output_ids = output_ids[:-1]
        response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
        print(response)
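
        # model.reg holds the grounder's regressed spans (normalized to [0, 1]) together with their
        # confidences; an empty list means no usable prediction was produced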
        grounder_success = len(model.reg) > 0

        if grounder_success:
            # 1. extract timestamps and confidences
            blob = model.reg[0].cpu().float()
            pred, conf = blob[:, :2] * duration, blob[:, 2:]

            print(pred[0], span, duration)

            # 2. clamp timestamps
            pred = pred.clamp(min=0, max=duration)

            # 3. round timestamps to units
            unit = getattr(DATASETS.get(args.dataset), 'UNIT', 0.001)
            pred = torch.round(pred / unit).long() * unit

            # 4. sort timestamps
            inds = (pred[:, 1] - pred[:, 0] < 0).nonzero()[:, 0]
            pred[inds] = pred[inds].roll(1)

            # 5. merge timestamps back with confidences
            pred = torch.cat((pred, conf), dim=1)
        else:
            print('WARNING: Failed to parse grounder response')
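            # fall back to a single window covering the whole video with confidence 1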
            pred = torch.Tensor([[0, duration, 1]])

        print(pred[0], span, duration)

        dump['vid'] = anno['vid']
        dump['qid'] = anno['qid']
        dump['pred_relevant_windows'] = pred.tolist()

        dumps.append(dump)
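
    # dump all predictions for this shard to the output file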
    nncore.dump(dumps, pred_path)