# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.
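# Example usage (script name, dataset key, and paths below are illustrative only):
#   python eval_grounder.py --dataset charades_sta --pred_path outputs/grounder \
#       --model_gnd_path model_zoo/grounder --split test --chunk 4 --index 0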

import argparse
import copy

import nncore
import torch

from videomind.constants import GROUNDER_PROMPT
from videomind.dataset.hybrid import DATASETS
from videomind.dataset.utils import process_vision_info
from videomind.model.builder import build_model
from videomind.utils.io import get_duration


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset')
    parser.add_argument('--pred_path')
    parser.add_argument('--model_gnd_path')
    parser.add_argument('--split', default='test', choices=['train', 'valid', 'test'])
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--device', default='auto')
    parser.add_argument('--chunk', type=int, default=1)
    parser.add_argument('--index', type=int, default=0)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
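
    # with chunked inference, each worker (identified by --index) writes its own shard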
    if args.chunk > 1:
        pred_path = nncore.join(args.pred_path, f'output_{args.index}.jsonl')
    else:
        pred_path = nncore.join(args.pred_path, 'output.jsonl')

    print(f'Dataset: {args.dataset}({args.split}) Chunk: {args.chunk} Index: {args.index} Output Path: {pred_path}')
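
    # build the grounder model and its processor, then resolve the device it landed on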
    model, processor = build_model(args.model_gnd_path, device=args.device)
    device = next(model.parameters()).device
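
    # load annotations for the requested split and keep this worker's strided slice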
    annos = DATASETS.get(args.dataset).load_annos(split=args.split)
    annos = [annos[i::args.chunk] for i in range(args.chunk)][args.index]

    dumps = []
    for i in nncore.ProgressBar(range(len(annos))):
        anno = copy.deepcopy(annos[i])
        dump = dict()

        video_path, query, duration, span = anno['video_path'], anno['query'], anno.get('duration'), anno.get('span')

        if duration is None:
            duration = get_duration(video_path, num_threads=args.num_threads)

        print()
        print(video_path)
        print(duration)
        print(query)
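
        # single-turn chat input: the video (sampled at 1 fps, at most 150 frames,
        # with a per-frame budget of 36-64 28x28-pixel patches) plus the grounding prompt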
        messages = [{
            'role': 'user',
            'content': [{
                'type': 'video',
                'video': video_path,
                'num_threads': args.num_threads,
                'min_pixels': 36 * 28 * 28,
                'max_pixels': 64 * 28 * 28,
                'max_frames': 150,
                'fps': 1.0
            }, {
                'type': 'text',
                'text': GROUNDER_PROMPT.format(query)
            }]
        }]

        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        print(text)

        images, videos = process_vision_info(messages)
        data = processor(text=[text], images=images, videos=videos, return_tensors='pt')
        data = data.to(device)
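
        # greedy decoding: sampling parameters are explicitly disabled for determinism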
        output_ids = model.generate(
            **data,
            do_sample=False,
            temperature=None,
            top_p=None,
            top_k=None,
            repetition_penalty=None,
            max_new_tokens=256)
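
        # keep only the newly generated tokens (batch size is 1) and drop a trailing EOS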
        assert data.input_ids.size(0) == output_ids.size(0) == 1
        output_ids = output_ids[0, data.input_ids.size(1):]

        if output_ids[-1] == processor.tokenizer.eos_token_id:
            output_ids = output_ids[:-1]

        response = processor.decode(output_ids, clean_up_tokenization_spaces=False)
        print(response)
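
        # the grounder exposes regression outputs via model.reg: each row holds a
        # (start, end) pair normalized to [0, 1] plus a confidence score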
        grounder_success = len(model.reg) > 0

        if grounder_success:
            # 1. extract timestamps and confidences
            blob = model.reg[0].cpu().float()
            pred, conf = blob[:, :2] * duration, blob[:, 2:]
            print(pred[0], span, duration)

            # 2. clamp timestamps to the video duration
            pred = pred.clamp(min=0, max=duration)

            # 3. round timestamps to the dataset's time unit
            unit = getattr(DATASETS.get(args.dataset), 'UNIT', 0.001)
            pred = torch.round(pred / unit).long() * unit

            # 4. swap start/end where the predicted span is reversed (end < start)
            inds = (pred[:, 1] - pred[:, 0] < 0).nonzero()[:, 0]
            pred[inds] = pred[inds].roll(1, dims=-1)

            # 5. merge timestamps back with confidences
            pred = torch.cat((pred, conf), dim=1)
        else:
            print('WARNING: Failed to parse grounder response')
            pred = torch.Tensor([[0, duration, 1]])

        print(pred[0], span, duration)

        dump['vid'] = anno['vid']
        dump['qid'] = anno['qid']
        dump['pred_relevant_windows'] = pred.tolist()

        dumps.append(dump)

    nncore.dump(dumps, pred_path)
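
    # NOTE: with chunk > 1, the per-worker shards (output_0.jsonl, output_1.jsonl, ...)
    # presumably need to be merged into a single file before evaluation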