import os
import logging

import numpy as np
import torch
import torch.backends.cudnn as cudnn

from main.config import TestOptions, setup_model
from utils.basic_utils import l2_normalize_np_array

logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
                    datefmt="%Y-%m-%d %H:%M:%S",
                    level=logging.INFO)

def load_model():
    logger.info("Setup config, data and model...")
    opt = TestOptions().parse()
    cudnn.benchmark = True
    cudnn.deterministic = False

    # setup_model also returns the criterion and two training-only objects;
    # only the model is needed for inference.
    model, _, _, _ = setup_model(opt)
    return model
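
# Usage sketch (assumption: the flag name follows TestOptions in main/config.py,
# and the script filename is illustrative):
#   python3 demo.py --resume <path/to/checkpoint.ckpt> ...
# TestOptions().parse() reads the model and checkpoint configuration from the
# command line, so launch this demo with the same flags used for evaluation.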

def load_data(save_dir):
    # Load pre-extracted video/text features and L2-normalize them.
    vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32)
    txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32)

    vid = torch.from_numpy(l2_normalize_np_array(vid))
    txt = torch.from_numpy(l2_normalize_np_array(txt))
    clip_len = 2  # seconds covered by each video clip
    ctx_l = vid.shape[0]  # number of clips in the video

    # Normalized anchor timestamp of each clip, duplicated for the (start, end) columns.
    timestamp = ((torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)

    # Append temporal endpoint features (TEF): each clip's normalized (start, end) position.
    tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
    tef_ed = tef_st + 1.0 / ctx_l
    tef = torch.stack([tef_st, tef_ed], dim=1)  # (Lv, 2)
    vid = torch.cat([vid, tef], dim=1)  # (Lv, Dv+2)

    # Add a batch dimension and build all-ones masks (a single sample has no padding).
    src_vid = vid.unsqueeze(0).cuda()
    src_txt = txt.unsqueeze(0).cuda()
    src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda()
    src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda()

    return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l
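
def make_dummy_features(save_dir, n_clips=75, n_tokens=12, dim=512):
    """Hedged helper, not part of the original pipeline: writes vid.npz/txt.npz
    in the exact layout load_data() expects -- a float32 'features' array of
    shape (N, dim). All shapes here are illustrative assumptions, not the
    dimensions produced by the real feature extractor."""
    rng = np.random.default_rng(0)
    np.savez(os.path.join(save_dir, 'vid.npz'),
             features=rng.standard_normal((n_clips, dim)).astype(np.float32))
    np.savez(os.path.join(save_dir, 'txt.npz'),
             features=rng.standard_normal((n_tokens, dim)).astype(np.float32))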

if __name__ == '__main__':
    clip_len = 2  # seconds per clip; must match the value used in load_data
    save_dir = '/data/home/qinghonglin/univtg/demo/tmp'

    model = load_model()
    model.eval()  # disable dropout etc. for inference
    src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir)
    with torch.no_grad():
        output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask)

    # Strip the batch dimension and move predictions to CPU for decoding.
    pred_logits = output['pred_logits'][0].cpu()
    pred_spans = output['pred_spans'][0].cpu()
    pred_saliency = output['saliency_scores'].cpu()

    # Top-1 moment: add the predicted span offsets to each clip's anchor timestamp,
    # take the highest-scoring candidate, and rescale from normalized units to seconds.
    top1 = (pred_spans + timestamp)[torch.argmax(pred_logits)] * ctx_l * clip_len
    print(top1)
    # Most salient clip, converted from clip index to seconds.
    print(pred_saliency.argmax() * clip_len)
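
    # Hedged extension, not in the original demo: rank all candidate moments and
    # print the top-k spans in seconds, reusing the decoding rule from the top-1
    # line above. k = 5 is an arbitrary illustrative choice.
    k = min(5, pred_logits.numel())
    topk_scores, topk_idx = pred_logits.flatten().topk(k)
    topk_spans = (pred_spans + timestamp)[topk_idx] * ctx_l * clip_len
    for score, span in zip(topk_scores.tolist(), topk_spans.tolist()):
        print(f'span (s): {span}, score: {score:.4f}')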