KevinQHLin commited on
Commit
9d0a4ae
1 Parent(s): eab7b75

Upload 60 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. app.py +236 -0
  3. examples/charades.mp4 +3 -0
  4. examples/ego4d.mp4 +3 -0
  5. examples/youtube.mp4 +3 -0
  6. main/__init__.py +0 -0
  7. main/_train_qfvs.py +293 -0
  8. main/config.py +378 -0
  9. main/config_hl.py +190 -0
  10. main/config_qfvs.json +14 -0
  11. main/dataset.py +1261 -0
  12. main/dataset_qfvs.py +284 -0
  13. main/inference_demo.py +81 -0
  14. main/inference_hl.py +229 -0
  15. main/inference_mr.py +273 -0
  16. main/inference_qfvs.py +342 -0
  17. main/train_hl.py +229 -0
  18. main/train_mr.py +266 -0
  19. main/train_qfvs.py +325 -0
  20. main/train_vlp.py +278 -0
  21. main/train_vlp_ddp.py +288 -0
  22. model/base.py +449 -0
  23. model/base_albef.py +478 -0
  24. model/base_droppath.py +449 -0
  25. model/base_droppath_ablation.py +474 -0
  26. model/base_droppath_qfvs.py +476 -0
  27. model/base_prompt.py +460 -0
  28. model/base_qfvs.py +476 -0
  29. model/matcher.py +107 -0
  30. model/moment_detr.py +462 -0
  31. model/position_encoding.py +126 -0
  32. model/transformer.py +471 -0
  33. model/transformer_encoder.py +159 -0
  34. model/transformer_encoder_droppath.py +194 -0
  35. model/univtg.py +450 -0
  36. model/univtg_ablation.py +474 -0
  37. model/univtg_qfvs.py +476 -0
  38. requirements.txt +355 -0
  39. results/omni/opt.json +111 -0
  40. run_on_video/__init__.py +1 -0
  41. run_on_video/clip/__init__.py +1 -0
  42. run_on_video/clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  43. run_on_video/clip/clip.py +195 -0
  44. run_on_video/clip/model.py +432 -0
  45. run_on_video/clip/simple_tokenizer.py +132 -0
  46. run_on_video/clip_feature_extractor.py +101 -0
  47. run_on_video/data_utils.py +170 -0
  48. run_on_video/preprocessing.py +25 -0
  49. run_on_video/text_extractor.py +36 -0
  50. run_on_video/video_extractor.py +94 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/charades.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ examples/ego4d.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ examples/youtube.mp4 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import time
4
+ import torch
5
+ import gradio as gr
6
+ import numpy as np
7
+ import argparse
8
+ import subprocess
9
+ from run_on_video import clip, vid2clip, txt2clip
10
+
11
+ parser = argparse.ArgumentParser(description='')
12
+ parser.add_argument('--save_dir', type=str, default='./tmp')
13
+ parser.add_argument('--resume', type=str, default='./results/omni/model_best.ckpt')
14
+ parser.add_argument("--gpu_id", type=int, default=2)
15
+ args = parser.parse_args()
16
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
17
+
18
+ #################################
19
+ model_version = "ViT-B/32"
20
+ output_feat_size = 512
21
+ clip_len = 2
22
+ overwrite = True
23
+ num_decoding_thread = 4
24
+ half_precision = False
25
+
26
+ clip_model, _ = clip.load(model_version, device=args.gpu_id, jit=False)
27
+
28
+ import logging
29
+ import torch.backends.cudnn as cudnn
30
+ from main.config import TestOptions, setup_model
31
+ from utils.basic_utils import l2_normalize_np_array
32
+
33
+ logger = logging.getLogger(__name__)
34
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
35
+ datefmt="%Y-%m-%d %H:%M:%S",
36
+ level=logging.INFO)
37
+
38
+ def load_model():
39
+ logger.info("Setup config, data and model...")
40
+ opt = TestOptions().parse(args)
41
+ # pdb.set_trace()
42
+ cudnn.benchmark = True
43
+ cudnn.deterministic = False
44
+
45
+ if opt.lr_warmup > 0:
46
+ total_steps = opt.n_epoch
47
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
48
+ opt.lr_warmup = [warmup_steps, total_steps]
49
+
50
+ model, criterion, _, _ = setup_model(opt)
51
+ return model
52
+
53
+ vtg_model = load_model()
54
+
55
+ def convert_to_hms(seconds):
56
+ return time.strftime('%H:%M:%S', time.gmtime(seconds))
57
+
58
+ def load_data(save_dir):
59
+ vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32)
60
+ txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32)
61
+
62
+ vid = torch.from_numpy(l2_normalize_np_array(vid))
63
+ txt = torch.from_numpy(l2_normalize_np_array(txt))
64
+ clip_len = 2
65
+ ctx_l = vid.shape[0]
66
+
67
+ timestamp = ( (torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
68
+
69
+ if True:
70
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
71
+ tef_ed = tef_st + 1.0 / ctx_l
72
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
73
+ vid = torch.cat([vid, tef], dim=1) # (Lv, Dv+2)
74
+
75
+ src_vid = vid.unsqueeze(0).cuda()
76
+ src_txt = txt.unsqueeze(0).cuda()
77
+ src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda()
78
+ src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda()
79
+
80
+ return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l
81
+
82
+ def forward(model, save_dir, query):
83
+ src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir)
84
+ src_vid = src_vid.cuda(args.gpu_id)
85
+ src_txt = src_txt.cuda(args.gpu_id)
86
+ src_vid_mask = src_vid_mask.cuda(args.gpu_id)
87
+ src_txt_mask = src_txt_mask.cuda(args.gpu_id)
88
+
89
+ with torch.no_grad():
90
+ output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask)
91
+
92
+ # prepare the model prediction
93
+ pred_logits = output['pred_logits'][0].cpu()
94
+ pred_spans = output['pred_spans'][0].cpu()
95
+ pred_saliency = output['saliency_scores'].cpu()
96
+
97
+ # prepare the model prediction
98
+ pred_windows = (pred_spans + timestamp) * ctx_l * clip_len
99
+ pred_confidence = pred_logits
100
+
101
+ # grounding
102
+ top1_window = pred_windows[torch.argmax(pred_confidence)].tolist()
103
+ top5_values, top5_indices = torch.topk(pred_confidence.flatten(), k=5)
104
+ top5_windows = pred_windows[top5_indices].tolist()
105
+
106
+ # print(f"The video duration is {convert_to_hms(src_vid.shape[1]*clip_len)}.")
107
+ q_response = f"For query: {query}"
108
+
109
+ mr_res = " - ".join([convert_to_hms(int(i)) for i in top1_window])
110
+ mr_response = f"The Top-1 interval is: {mr_res}"
111
+
112
+ hl_res = convert_to_hms(torch.argmax(pred_saliency) * clip_len)
113
+ hl_response = f"The Top-1 highlight is: {hl_res}"
114
+ return '\n'.join([q_response, mr_response, hl_response])
115
+
116
+ def extract_vid(vid_path, state):
117
+ history = state['messages']
118
+ vid_features = vid2clip(clip_model, vid_path, args.save_dir)
119
+ history.append({"role": "user", "content": "Finish extracting video features."})
120
+ history.append({"role": "system", "content": "Please Enter the text query."})
121
+ chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history),2)]
122
+ return '', chat_messages, state
123
+
124
+ def extract_txt(txt):
125
+ txt_features = txt2clip(clip_model, txt, args.save_dir)
126
+ return
127
+
128
+ def download_video(url, save_dir='./examples', size=768):
129
+ save_path = f'{save_dir}/{url}.mp4'
130
+ cmd = f'yt-dlp -S ext:mp4:m4a --throttled-rate 5M -f "best[width<={size}][height<={size}]" --output {save_path} --merge-output-format mp4 https://www.youtube.com/embed/{url}'
131
+ if not os.path.exists(save_path):
132
+ try:
133
+ subprocess.call(cmd, shell=True)
134
+ except:
135
+ return None
136
+ return save_path
137
+
138
+ def get_empty_state():
139
+ return {"total_tokens": 0, "messages": []}
140
+
141
+ def submit_message(prompt, state):
142
+ history = state['messages']
143
+
144
+ if not prompt:
145
+ return gr.update(value=''), [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)], state
146
+
147
+ prompt_msg = { "role": "user", "content": prompt }
148
+
149
+ try:
150
+ history.append(prompt_msg)
151
+ # answer = vlogger.chat2video(prompt)
152
+ # answer = prompt
153
+ extract_txt(prompt)
154
+ answer = forward(vtg_model, args.save_dir, prompt)
155
+ history.append({"role": "system", "content": answer})
156
+
157
+ except Exception as e:
158
+ history.append(prompt_msg)
159
+ history.append({
160
+ "role": "system",
161
+ "content": f"Error: {e}"
162
+ })
163
+
164
+ chat_messages = [(history[i]['content'], history[i+1]['content']) for i in range(0, len(history)-1, 2)]
165
+ return '', chat_messages, state
166
+
167
+
168
+ def clear_conversation():
169
+ return gr.update(value=None, visible=True), gr.update(value=None, interactive=True), None, gr.update(value=None, visible=True), get_empty_state()
170
+
171
+
172
+ def subvid_fn(vid):
173
+ save_path = download_video(vid)
174
+ return gr.update(value=save_path)
175
+
176
+
177
+ css = """
178
+ #col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
179
+ #video_inp {min-height: 100px}
180
+ #chatbox {min-height: 100px;}
181
+ #header {text-align: center;}
182
+ #hint {font-size: 1.0em; padding: 0.5em; margin: 0;}
183
+ .message { font-size: 1.2em; }
184
+ """
185
+
186
+ with gr.Blocks(css=css) as demo:
187
+
188
+ state = gr.State(get_empty_state())
189
+
190
+
191
+ with gr.Column(elem_id="col-container"):
192
+ gr.Markdown("""## 🤖️ UniVTG: Towards Unified Video-Language Temporal Grounding
193
+ Given a video and text query, return relevant window and highlight.""",
194
+ elem_id="header")
195
+
196
+ with gr.Row():
197
+ with gr.Column():
198
+ video_inp = gr.Video(label="video_input")
199
+ gr.Markdown("👋 **Step1**: Select a video in Examples (bottom) or input youtube video_id in this textbox, *e.g.* *G7zJK6lcbyU* for https://www.youtube.com/watch?v=G7zJK6lcbyU", elem_id="hint")
200
+ with gr.Row():
201
+ video_id = gr.Textbox(value="", placeholder="Youtube video url", show_label=False)
202
+ vidsub_btn = gr.Button("(Optional) Submit Youtube id")
203
+
204
+ with gr.Column():
205
+ vid_ext = gr.Button("Step2: Extract video feature, may takes a while")
206
+ # vlog_outp = gr.Textbox(label="Document output", lines=40)
207
+ total_tokens_str = gr.Markdown(elem_id="total_tokens_str")
208
+
209
+ chatbot = gr.Chatbot(elem_id="chatbox")
210
+ input_message = gr.Textbox(show_label=False, placeholder="Enter text query and press enter", visible=True).style(container=False)
211
+ btn_submit = gr.Button("Step3: Enter your text query")
212
+ btn_clear_conversation = gr.Button("🔃 Clear")
213
+
214
+ examples = gr.Examples(
215
+ examples=[
216
+ ["./examples/youtube.mp4"],
217
+ ["./examples/charades.mp4"],
218
+ ["./examples/ego4d.mp4"],
219
+ ],
220
+ inputs=[video_inp],
221
+ )
222
+
223
+ gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/anzorq/chatgpt-demo?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br></center>''')
224
+
225
+ btn_submit.click(submit_message, [input_message, state], [input_message, chatbot])
226
+ input_message.submit(submit_message, [input_message, state], [input_message, chatbot])
227
+ # btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, vlog_outp, state])
228
+ btn_clear_conversation.click(clear_conversation, [], [input_message, video_inp, chatbot, state])
229
+ vid_ext.click(extract_vid, [video_inp, state], [input_message, chatbot])
230
+ vidsub_btn.click(subvid_fn, [video_id], [video_inp])
231
+
232
+ demo.load(queur=False)
233
+
234
+
235
+ demo.queue(concurrency_count=10)
236
+ demo.launch(height='800px', server_port=2253, debug=True, share=True)
examples/charades.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fa3d1ba99bf28103844e1313cc5543b7c626d87c42a1c18108c2a69479a6d679
3
+ size 1301669
examples/ego4d.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf1271d42415c793e659bebbd48394326cc50e970d44e6fdd0af5dfb4cb4ede4
3
+ size 28306388
examples/youtube.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dd6b483e5346a777b5d6448460c5e30b8fe46aa1133cf6bba94c84dd7262b49
3
+ size 47353846
main/__init__.py ADDED
File without changes
main/_train_qfvs.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import time
4
+ import json
5
+ import pprint
6
+ import random
7
+ import importlib
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+ from collections import defaultdict
11
+
12
+ import h5py
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.backends.cudnn as cudnn
16
+ from torch.utils.data import DataLoader
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ import sys
20
+ sys.path.append('/data/home/qinghonglin/univtg')
21
+ from main.config import BaseOptions, setup_model
22
+ from main.dataset import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
23
+ from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle
24
+ from utils.model_utils import count_parameters
25
+ from eval.qfvs import calculate_semantic_matching, load_videos_tag
26
+
27
+ import logging
28
+ logger = logging.getLogger(__name__)
29
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
30
+ datefmt="%Y-%m-%d %H:%M:%S",
31
+ level=logging.INFO)
32
+
33
+ def eval_epoch(model, config, opt):
34
+ model.eval()
35
+ f1_sum = 0; p_sum = 0; r_sum = 0
36
+
37
+ assert len(config['test_videos']) == 1
38
+ video_id = config['test_videos'][0]
39
+ embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")
40
+
41
+ feat_type = config['vid_feature']
42
+ feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
43
+ features = torch.tensor(feat['feature'][()]).unsqueeze(0).cuda()
44
+ # pdb.set_trace()
45
+ # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()
46
+
47
+ # dim = features.shape[-1]
48
+ # ctx_l = seg_len.sum().cpu()
49
+
50
+ dim = features.shape[-1]
51
+ ctx_l = features.shape[1]
52
+ seg_len = torch.ones(ctx_l)
53
+ features = features.reshape(-1, dim)[:ctx_l]
54
+
55
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
56
+ tef_ed = tef_st + 1.0 / ctx_l
57
+ tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2)
58
+ features = torch.cat([features, tef], dim=1) # (Lv, Dv+2)
59
+
60
+ transfer = {"Cupglass": "Glass",
61
+ "Musicalinstrument": "Instrument",
62
+ "Petsanimal": "Animal"}
63
+
64
+ for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
65
+ evaluation_num=len(files)
66
+ for file in files:
67
+ summaries_GT=[]
68
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f:
69
+ for line in f.readlines():
70
+ summaries_GT.append(int(line.strip()))
71
+
72
+ concept1, concept2 = file.split('_')[0:2]
73
+
74
+ ##############
75
+ if concept1 in transfer:
76
+ concept1 = transfer[concept1]
77
+ if concept2 in transfer:
78
+ concept2 = transfer[concept2]
79
+ concept1 = embedding[concept1]
80
+ concept2 = embedding[concept2]
81
+
82
+ data = {
83
+ 'features':features,
84
+ 'seg_len': seg_len,
85
+ 'tokens_pad1':torch.from_numpy(concept1),
86
+ 'tokens_pad2':torch.from_numpy(concept2),
87
+ }
88
+
89
+ input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)
90
+
91
+ summaries_GT = [x - 1 for x in summaries_GT]
92
+ video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")
93
+
94
+
95
+ output_type = 'pred_logits' # only saliency.
96
+ # if opt.f_loss_coef == 0:
97
+ # output_type = 'saliency_scores' # only saliency.
98
+ # elif opt.s_loss_intra_coef == 0:
99
+ # output_type = 'pred_logits' # cls is default.
100
+ # else:
101
+ # output_type = ['pred_logits', 'saliency_scores']
102
+
103
+ # if opt.qfvs_score_multiple > 0:
104
+ # output_type = ['pred_logits', 'saliency_scores']
105
+
106
+ with torch.no_grad():
107
+ if not isinstance(output_type, list):
108
+ score1 = model(**input1)[output_type].squeeze()
109
+ # score1 = score1.masked_select(mask)
110
+ score2 = model(**input2)[output_type].squeeze()
111
+ # score2 = score2.masked_select(mask)
112
+
113
+ score = model(**input_oracle)[output_type].squeeze()
114
+ # score = score.masked_select(mask)
115
+ else:
116
+ score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
117
+ for output_t in output_type:
118
+ # score1 *= model(**input1)[output_t].squeeze() #.masked_select(mask)
119
+ # score2 *= model(**input2)[output_t].squeeze() #.masked_select(mask)
120
+ # score *= model(**input_oracle)[output_t].squeeze() #.masked_select(mask)
121
+ score1 += model(**input1)[output_t].squeeze() #.masked_select(mask)
122
+ score2 += model(**input2)[output_t].squeeze() #.masked_select(mask)
123
+ score += model(**input_oracle)[output_t].squeeze() #.masked_select(mask)
124
+
125
+ score = score
126
+ # score = score + score1 + score2
127
+
128
+ # since video4 features dim is greater than video_shots_tag.
129
+ score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
130
+ _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))
131
+ p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
132
+ f1_sum+=f1; r_sum+=r; p_sum+=p
133
+
134
+ return {'F': round(100* f1_sum/evaluation_num,2) ,
135
+ 'R': round(100* r_sum/evaluation_num,2) ,
136
+ 'P': round(100* p_sum/evaluation_num,2) }
137
+
138
+ def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
139
+ model.train()
140
+ criterion.train()
141
+
142
+ # init meters
143
+ time_meters = defaultdict(AverageMeter)
144
+ loss_meters = defaultdict(AverageMeter)
145
+
146
+ timer_dataloading = time.time()
147
+ loss_total = 0
148
+
149
+ # optimizer.zero_grad()
150
+ for batch_idx, batch in enumerate(tqdm(train_loader)):
151
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
152
+ timer_start = time.time()
153
+ model_input1, model_input2, model_input_oracle, \
154
+ model_gt1, model_gt2, model_gt_oracle, \
155
+ mask_GT = prepare_batch_inputs_qfvs(batch, config)
156
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
157
+
158
+ timer_start = time.time()
159
+ output1 = model(**model_input1)
160
+ output2 = model(**model_input2)
161
+ output_oracle = model(**model_input_oracle)
162
+
163
+ loss_dict = {}
164
+ loss_dict1 = criterion(output1, model_gt1)
165
+ loss_dict2 = criterion(output2, model_gt2)
166
+ loss_dict3 = criterion(output_oracle, model_gt_oracle)
167
+
168
+ weight_dict = criterion.weight_dict
169
+ for k in loss_dict1.keys():
170
+ loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]
171
+
172
+ # print(loss_dict)
173
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
174
+ loss_total += losses.item()
175
+
176
+ time_meters["model_forward_time"].update(time.time() - timer_start)
177
+ timer_start = time.time()
178
+ # optimizer.zero_grad()
179
+ optimizer.zero_grad()
180
+ losses.backward()
181
+ if opt.grad_clip > 0:
182
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
183
+ # if ((batch_idx + 1) % opt.bsz==0) or (batch_idx == len(train_loader)-1):
184
+ # pdb.set_trace()
185
+ # optimizer.step()
186
+ # optimizer.zero_grad()
187
+ optimizer.step()
188
+ time_meters["model_backward_time"].update(time.time() - timer_start)
189
+
190
+ timer_dataloading = time.time()
191
+ return round(loss_total / len(train_loader), 2)
192
+
193
+ # train in single domain.
194
+ def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
195
+ if opt.device.type == "cuda":
196
+ logger.info("CUDA enabled.")
197
+ model.to(opt.device)
198
+
199
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
200
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
201
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
202
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
203
+
204
+ prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0}
205
+ if opt.start_epoch is None:
206
+ start_epoch = -1 if opt.eval_init else 0
207
+ else:
208
+ start_epoch = opt.start_epoch
209
+
210
+ val_score = eval_epoch(model, config, opt)
211
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
212
+ logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
213
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
214
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
215
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
216
+ if epoch_i > -1:
217
+ loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
218
+ lr_scheduler.step()
219
+ eval_epoch_interval = opt.eval_epoch
220
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
221
+ with torch.no_grad():
222
+ val_score = eval_epoch(model, config, opt)
223
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
224
+ logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
225
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
226
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
227
+
228
+ if prev_best_score['Fscore'] < val_score['F']:
229
+ prev_best_score['Fscore'] = val_score['F']
230
+ prev_best_score['Precision'] = val_score['P']
231
+ prev_best_score['Recall'] = val_score['R']
232
+
233
+ checkpoint = {
234
+ "model": model.state_dict(),
235
+ "optimizer": optimizer.state_dict(),
236
+ "epoch": epoch_i,
237
+ "opt": opt
238
+ }
239
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
240
+ tb_writer.close()
241
+ return prev_best_score
242
+
243
+ def start_training():
244
+ logger.info("Setup config, data and model...")
245
+ opt = BaseOptions().parse()
246
+ set_seed(opt.seed)
247
+
248
+ config = load_json("./main/config_qfvs.json")
249
+
250
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
251
+
252
+ # key -> test video; value -> training videos.
253
+ qfvs_split = {1: [2, 3, 4],
254
+ 2: [1, 3, 4],
255
+ 3: [1, 2, 4],
256
+ 4: [1, 2, 3]}
257
+ # qfvs_split = {
258
+ # 2: [1, 3, 4],
259
+ # 3: [1, 2, 4],
260
+ # }
261
+
262
+ scores_videos = {}
263
+ for test_id, splits in qfvs_split.items():
264
+ logger.info(f"Start Training {opt.dset_name}: {test_id}")
265
+ config['train_videos'] = qfvs_split[test_id]
266
+ config['test_videos'] = [test_id]
267
+ train_dataset = DatasetQFVS(config)
268
+ train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)
269
+
270
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
271
+ count_parameters(model)
272
+ best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
273
+ scores_videos['V'+str(test_id)] = best_score
274
+
275
+ # save the final results.
276
+ avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
277
+ avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
278
+ avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
279
+ scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall}
280
+
281
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
282
+ save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)
283
+
284
+ tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
285
+ tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
286
+ tb_writer.close()
287
+
288
+ print(scores_videos)
289
+ return
290
+
291
+ if __name__ == '__main__':
292
+ start_training()
293
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
main/config.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import time
4
+ import torch
5
+ import logging
6
+ import argparse
7
+ import importlib
8
+ from utils.basic_utils import mkdirp, remkdirp, \
9
+ load_json, save_json, make_zipfile, dict_to_markdown
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
13
+ datefmt="%Y-%m-%d %H:%M:%S",
14
+ level=logging.INFO)
15
+
16
+ class BaseOptions(object):
17
+ saved_option_filename = "opt.json"
18
+ ckpt_filename = "model.ckpt"
19
+ tensorboard_log_dir = "tensorboard_log"
20
+ train_log_filename = "train.log.txt"
21
+ eval_log_filename = "eval.log.txt"
22
+
23
+ def __init__(self):
24
+ self.parser = None
25
+ self.initialized = False
26
+ self.opt = None
27
+
28
+ def initialize(self):
29
+ self.initialized = True
30
+ parser = argparse.ArgumentParser()
31
+ # * Running configs
32
+ parser.add_argument("--dset_type", type=str, choices=["mr", "hl", "vs", "vlp"]) # moment retrieval, highlight detection, and video summarization
33
+ parser.add_argument("--dset_name", type=str, choices=["qvhighlights", "charades", "anet", "tvsum", "youtube", "summe", "ego4d", "qfvs", "video2gif", "coin", "hacs", "vlp", "videocc", "tacos"])
34
+ parser.add_argument("--domain_name", type=str, default=None)
35
+ parser.add_argument("--model_id", type=str, default="moment_detr")
36
+ parser.add_argument("--exp_id", type=str, default="debug", help="id of this run, required at training")
37
+ parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
38
+ parser.add_argument("--gpu_id", type=int, default=0)
39
+ parser.add_argument("--debug", action="store_true",
40
+ help="debug (fast) mode, break all loops, do not load all data into memory.")
41
+ parser.add_argument("--seed", type=int, default=2018, help="random seed")
42
+
43
+ # * DDP
44
+ parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
45
+
46
+
47
+ parser.add_argument("--eval_split_name", type=str, default="val",
48
+ help="should match keys in video_duration_idx_path, must set for VCMR")
49
+ parser.add_argument("--data_ratio", type=float, default=1.0,
50
+ help="how many training and eval data to use. 1.0: use all, 0.1: use 10%."
51
+ "Use small portion for debug purposes. Note this is different from --debug, "
52
+ "which works by breaking the loops, typically they are not used together.")
53
+ parser.add_argument("--results_root", type=str, default="results")
54
+ parser.add_argument("--num_workers", type=int, default=0,
55
+ help="num subprocesses used to load the data, 0: use main process")
56
+ parser.add_argument("--no_pin_memory", action="store_true",
57
+ help="Don't use pin_memory=True for dataloader. "
58
+ "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")
59
+
60
+ # * Training configs
61
+ parser.add_argument("--bsz", type=int, default=32, help="mini-batch size")
62
+ parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs to run")
63
+ parser.add_argument("--max_es_cnt", type=int, default=200,
64
+ help="number of epochs to early stop, use -1 to disable early stop")
65
+ parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
66
+ parser.add_argument("--lr_drop", type=int, default=400, help="drop learning rate to 1/10 every lr_drop epochs")
67
+ parser.add_argument("--lr_gamma", type=float, default=0.1, help="lr reduces the gamma times after the `drop' epoch")
68
+ parser.add_argument("--lr_warmup", type=float, default=-1, help="linear warmup scheme")
69
+ parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
70
+ parser.add_argument("--grad_clip", type=float, default=0.1, help="perform gradient clip, -1: disable")
71
+
72
+ # ** Loss coefficients
73
+ # *** boundary branch
74
+ parser.add_argument("--span_loss_type", default="l1", type=str, choices=['l1', 'ce'],
75
+ help="l1: (center-x, width) regression. ce: (st_idx, ed_idx) classification.")
76
+ parser.add_argument('--b_loss_coef', default=10, type=float) # boundary regression e.g., l1
77
+ parser.add_argument('--g_loss_coef', default=1, type=float) # giou loss
78
+ # *** foreground branch
79
+ parser.add_argument('--eos_coef', default=0.1, type=float, help="relative classification weight of the no-object class")
80
+ parser.add_argument('--f_loss_coef', default=4, type=float) # cls loss for foreground
81
+ # *** saliency branch
82
+ parser.add_argument("--s_loss_intra_coef", type=float, default=1., help="inter-video (frame-level) saliency loss e.g. momentdetr saliency loss")
83
+ parser.add_argument("--s_loss_inter_coef", type=float, default=0., help="intra-video (sample-level) saliency loss,")
84
+
85
+ # * Eval configs
86
+ parser.add_argument("--main_metric", type=str, default="MR-full-mAP")
87
+ parser.add_argument('--eval_mode', default=None, type=str,
88
+ help="how to integrate foreground and saliency for better prediction")
89
+ parser.add_argument("--eval_bsz", type=int, default=100,
90
+ help="mini-batch size at inference, for query")
91
+ parser.add_argument("--eval_epoch", type=int, default=5,
92
+ help="number of epochs for once inference")
93
+ parser.add_argument("--eval_init", action="store_true", help="evaluate model before training i.e. `epoch=-1'")
94
+ parser.add_argument("--save_interval", type=int, default=50)
95
+
96
+ parser.add_argument("--resume", type=str, default=None,
97
+ help="checkpoint path to resume or evaluate, without --resume_all this only load weights")
98
+ parser.add_argument("--resume_dir", type=str, default=None,
99
+ help="checkpoint path to resume or evaluate, without --resume_all this only load weights")
100
+ parser.add_argument("--resume_all", action="store_true",
101
+ help="if --resume_all, load optimizer/scheduler/epoch as well")
102
+ parser.add_argument("--start_epoch", type=int, default=None,
103
+ help="if None, will be set automatically when using --resume_all")
104
+
105
+ # ** NMS configs
106
+ parser.add_argument("--no_sort_results", action="store_true",
107
+ help="do not sort results, use this for moment query visualization")
108
+ parser.add_argument("--max_before_nms", type=int, default=10)
109
+ parser.add_argument("--max_after_nms", type=int, default=10)
110
+ parser.add_argument("--conf_thd", type=float, default=0.0, help="only keep windows with conf >= conf_thd")
111
+ parser.add_argument("--nms_thd", type=float, default=-1,
112
+ help="additionally use non-maximum suppression "
113
+ "(or non-minimum suppression for distance)"
114
+ "to post-processing the predictions. "
115
+ "-1: do not use nms. [0, 1]")
116
+
117
+ # * Dataset configs
118
+ parser.add_argument("--use_cache", type=int, default=-1, help="Preload features into cache for fast IO")
119
+ parser.add_argument("--max_q_l", type=int, default=75)
120
+ parser.add_argument("--max_v_l", type=int, default=75)
121
+ parser.add_argument("--clip_length", type=float, default=1.0)
122
+ parser.add_argument("--clip_len_list", type=int, nargs='+')
123
+ parser.add_argument("--max_windows", type=int, default=5)
124
+
125
+ parser.add_argument("--add_easy_negative", type=int, default=1)
126
+ parser.add_argument("--easy_negative_only", type=int, default=1)
127
+ parser.add_argument("--round_multiple", type=int, default=1)
128
+
129
+ parser.add_argument("--train_path", type=str, default=None, nargs='+')
130
+ parser.add_argument("--eval_path", type=str, default=None,
131
+ help="Evaluating during training, for Dev set. If None, will only do training, ")
132
+ parser.add_argument("--train_path_list", type=str, nargs='+')
133
+ parser.add_argument("--eval_path_list", type=str, nargs='+')
134
+ parser.add_argument("--feat_root_list", type=str, nargs='+')
135
+
136
+ parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalize video feat")
137
+ parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalize text feat")
138
+ parser.add_argument("--v_feat_dirs", type=str, nargs="+",
139
+ help="video feature dirs. If more than one, will concat their features. "
140
+ "Note that sub ctx features are also accepted here.")
141
+ parser.add_argument("--t_feat_dir", type=str, help="text/query feature dir")
142
+ parser.add_argument("--v_feat_dim", type=int, help="video feature dim")
143
+ parser.add_argument("--t_feat_dim", type=int, help="text/query feature dim")
144
+ parser.add_argument("--ctx_mode", type=str, default="video_tef")
145
+ parser.add_argument("--v_feat_types", type=str)
146
+ parser.add_argument("--t_feat_type", type=str)
147
+
148
+ # * Model configs
149
+ parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
150
+ help="Type of positional embedding to use on top of the image features")
151
+ parser.add_argument("--n_input_proj", type=int, default=2, help="#layers to vid/txt projector")
152
+ parser.add_argument("--temperature", type=float, default=0.07, help="temperature nce contrastive_align_loss")
153
+
154
+ # ** Transformer
155
+ parser.add_argument('--enc_layers', default=4, type=int,
156
+ help="Number of encoding layers in the transformer")
157
+ parser.add_argument('--sub_enc_layers', default=2, type=int,
158
+ help="Number of encoding layers in the video / text transformer in albef-style.")
159
+ parser.add_argument('--dec_layers', default=2, type=int,
160
+ help="Number of decoding layers in the transformer, N/A for UniVTG")
161
+ parser.add_argument('--dim_feedforward', default=1024, type=int,
162
+ help="Intermediate size of the feedforward layers in the transformer blocks")
163
+ parser.add_argument('--hidden_dim', default=256, type=int,
164
+ help="Size of the embeddings (dimension of the transformer)")
165
+ parser.add_argument('--input_dropout', default=0.5, type=float,
166
+ help="Dropout applied in input")
167
+ parser.add_argument('--dropout', default=0.1, type=float,
168
+ help="Dropout applied in the transformer")
169
+ parser.add_argument('--droppath', default=0.1, type=float,
170
+ help="Droppath applied in the transformer")
171
+ parser.add_argument("--txt_drop_ratio", default=0, type=float,
172
+ help="drop txt_drop_ratio tokens from text input. 0.1=10%")
173
+ parser.add_argument("--use_txt_pos", action="store_true", help="use position_embedding for text as well.")
174
+ parser.add_argument('--nheads', default=8, type=int,
175
+ help="Number of attention heads inside the transformer's attentions")
176
+ parser.add_argument('--num_queries', default=10, type=int,
177
+ help="Number of query slots")
178
+ parser.add_argument('--pre_norm', action='store_true')
179
+
180
+ # ** momentdetr configs e.g. Matcher, saliency margin
181
+ parser.add_argument('--set_cost_span', default=10, type=float,
182
+ help="L1 span coefficient in the matching cost")
183
+ parser.add_argument('--set_cost_giou', default=1, type=float,
184
+ help="giou span coefficient in the matching cost")
185
+ parser.add_argument('--set_cost_class', default=4, type=float,
186
+ help="Class coefficient in the matching cost")
187
+ parser.add_argument("--saliency_margin", type=float, default=0.2)
188
+ parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_true',
189
+ help="Disables auxiliary decoding losses (loss at each layer)")
190
+
191
+ # * Query-Force Video Summarization
192
+ parser.add_argument("--max_segment_num", type=int, default=20)
193
+ parser.add_argument("--max_frame_num", type=int, default=200)
194
+ parser.add_argument("--top_percent", type=float, default=0.02)
195
+
196
+ parser.add_argument("--qfvs_vid_feature", type=str, default='fps1')
197
+ parser.add_argument("--qfvs_txt_feature", type=str, default='query')
198
+ parser.add_argument("--qfvs_split", type=int, default=-1)
199
+
200
+ parser.add_argument("--qfvs_dense_shot", type=int, default=-1)
201
+ parser.add_argument("--qfvs_score_ensemble", type=int, default=-1)
202
+ parser.add_argument("--qfvs_score_gather", type=int, default=-1)
203
+ parser.add_argument("--qfvs_loss_gather", type=int, default=-1)
204
+ self.parser = parser
205
+
206
+ def display_save(self, opt):
207
+ args = vars(opt)
208
+ # Display settings
209
+ print(dict_to_markdown(vars(opt), max_str_len=120))
210
+ # Save settings
211
+ if not isinstance(self, TestOptions):
212
+ option_file_path = os.path.join(opt.results_dir, self.saved_option_filename) # not yaml file indeed
213
+ save_json(args, option_file_path, save_pretty=True)
214
+
215
+ def parse(self, args=None):
216
+ if not self.initialized:
217
+ self.initialize()
218
+ opt = self.parser.parse_args()
219
+
220
+ if args is not None:
221
+ args_dict = vars(args)
222
+ opt_dict = vars(opt)
223
+ for key, value in args_dict.items():
224
+ opt_dict[key] = value
225
+ opt = argparse.Namespace(**opt_dict)
226
+ opt.model_dir = os.path.dirname(opt.resume)
227
+ torch.cuda.set_device(opt.gpu_id)
228
+
229
+ if opt.debug:
230
+ opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
231
+ opt.num_workers = 0
232
+
233
+ if isinstance(self, TestOptions):
234
+ # modify model_dir to absolute path
235
+ # opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
236
+ opt.model_dir = os.path.dirname(opt.resume)
237
+ saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
238
+ for arg in saved_options: # use saved options to overwrite all BaseOptions args.
239
+ if arg not in ["results_root", "num_workers", "nms_thd", "debug", "max_before_nms", "max_after_nms"
240
+ "max_pred_l", "min_pred_l", "gpu_id",
241
+ "resume", "resume_all", "no_sort_results",
242
+ "eval_path", "eval_split_name"]:
243
+ # "dset_name", "v_feat_dirs", "t_feat_dir"]:
244
+ setattr(opt, arg, saved_options[arg])
245
+ # opt.no_core_driver = True
246
+ if opt.eval_results_dir is not None:
247
+ opt.results_dir = opt.eval_results_dir
248
+ else:
249
+ if opt.exp_id is None:
250
+ raise ValueError("--exp_id is required for at a training option!")
251
+
252
+ # ctx_str = opt.ctx_mode + "_sub" if any(["sub_ctx" in p for p in opt.v_feat_dirs]) else opt.ctx_mode
253
+
254
+ if 'debug' not in opt.exp_id:
255
+ opt.results_dir = os.path.join(opt.results_root, "-".join([opt.dset_type, opt.dset_name]), "-".join([opt.exp_id, opt.v_feat_types, opt.t_feat_type, time.strftime("%Y_%m_%d_%H")]))
256
+ else:
257
+ opt.results_dir = os.path.join(opt.results_root, "-".join([opt.dset_type, opt.dset_name]), opt.exp_id) # debug mode.
258
+
259
+ if int(opt.local_rank) in [0, -1]:
260
+ # mkdirp(opt.results_dir)
261
+ remkdirp(opt.results_dir) # remove dir and remkdir it.
262
+
263
+ # save a copy of current code
264
+ code_dir = os.path.dirname(os.path.realpath(__file__))
265
+ code_zip_filename = os.path.join(opt.results_dir, "code.zip")
266
+ make_zipfile(code_dir, code_zip_filename,
267
+ enclosing_dir="code",
268
+ exclude_dirs_substring="results",
269
+ exclude_dirs=["results", "debug_results", "__pycache__"],
270
+ exclude_extensions=[".pyc", ".ipynb", ".swap"], )
271
+
272
+ if int(opt.local_rank) in [0, -1]:
273
+ self.display_save(opt)
274
+ opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
275
+ opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
276
+ opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
277
+ opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
278
+ # opt.device = torch.device("cuda" if opt.device >= 0 else "cpu")
279
+
280
+ if int(opt.local_rank) in [-1]:
281
+ torch.cuda.set_device(opt.gpu_id)
282
+ opt.pin_memory = not opt.no_pin_memory
283
+
284
+ if opt.local_rank == -1:
285
+ torch.cuda.set_device(opt.gpu_id)
286
+
287
+ opt.use_tef = "tef" in opt.ctx_mode
288
+ opt.use_video = "video" in opt.ctx_mode
289
+ if not opt.use_video:
290
+ opt.v_feat_dim = 0
291
+ if opt.use_tef:
292
+ opt.v_feat_dim += 2
293
+
294
+ self.opt = opt
295
+ return opt
296
+
297
+ class TestOptions(BaseOptions):
298
+ """add additional options for evaluating"""
299
+
300
+ def initialize(self):
301
+ BaseOptions.initialize(self)
302
+ # also need to specify --eval_split_name
303
+ self.parser.add_argument("--eval_id", type=str, help="evaluation id")
304
+ self.parser.add_argument("--eval_results_dir", type=str, default=None,
305
+ help="dir to save results, if not set, fall back to training results_dir")
306
+ self.parser.add_argument("--model_dir", type=str,
307
+ help="dir contains the model file, will be converted to absolute path afterwards")
308
+
309
+ class WarmupStepLR(torch.optim.lr_scheduler.StepLR):
310
+ def __init__(self, optimizer, warmup_steps, step_size, gamma=0.1, last_epoch=-1):
311
+ self.warmup_steps = warmup_steps
312
+ self.step_size = step_size
313
+ self.gamma = gamma
314
+ super(WarmupStepLR, self).__init__(optimizer, step_size, gamma=self.gamma, last_epoch=last_epoch)
315
+ def get_lr(self):
316
+ if not self._get_lr_called_within_step:
317
+ import warnings
318
+ warnings.warn("To get the last learning rate computed by the scheduler, "
319
+ "please use `get_last_lr()`.", DeprecationWarning)
320
+ # e.g. warmup_steps = 10, case: 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21...
321
+ if self.last_epoch == self.warmup_steps or(self.last_epoch % self.step_size != 0 and self.last_epoch > self.warmup_steps):
322
+ return [group['lr'] for group in self.optimizer.param_groups]
323
+ # e.g. warmup_steps = 10, case: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
324
+ elif self.last_epoch < self.warmup_steps:
325
+ return [group['initial_lr'] * float(self.last_epoch + 1) / float(self.warmup_steps) for group in self.optimizer.param_groups]
326
+
327
+
328
+ # e.g. warmup_steps = 10, case: 10, 20, 30, 40...
329
+ return [group['lr'] * self.gamma
330
+ for group in self.optimizer.param_groups]
331
+ def _get_closed_form_lr(self):
332
+ if self.last_epoch <= self.warmup_steps:
333
+ return [base_lr * float(self.last_epoch) / (self.warmup_steps) for base_lr in self.base_lrs]
334
+ else:
335
+ return [base_lr * self.gamma ** ((self.last_epoch - self.warmup_steps)// self.step_size) for base_lr in self.base_lrs]
336
+
337
+ def setup_model(opt):
338
+ """setup model/optimizer/scheduler and load checkpoints when needed"""
339
+ logger.info("setup model/optimizer/scheduler")
340
+
341
+ importer = importlib.import_module('.'.join(['model', opt.model_id]))
342
+ model, criterion = importer.build_model(opt)
343
+
344
+ if int(opt.device) >= 0:
345
+ logger.info("CUDA enabled.")
346
+ model.to(opt.gpu_id)
347
+ criterion.to(opt.gpu_id)
348
+
349
+ param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}]
350
+ optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd)
351
+
352
+ if opt.lr_warmup != -1 and opt.lr_drop > 0:
353
+ lr_scheduler = WarmupStepLR(optimizer, warmup_steps=opt.lr_warmup[0], step_size=opt.lr_drop, gamma=opt.lr_gamma)
354
+
355
+ elif opt.lr_warmup != -1:
356
+ from transformers import get_constant_schedule_with_warmup
357
+ lr_scheduler = get_constant_schedule_with_warmup(optimizer, opt.lr_warmup[0])
358
+
359
+ elif opt.lr_drop > 0:
360
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop, gamma=opt.lr_gamma)
361
+
362
+ if opt.resume is not None:
363
+ logger.info(f"Load checkpoint from {opt.resume}")
364
+ checkpoint = torch.load(opt.resume, map_location="cpu")
365
+
366
+ for key in list(checkpoint["model"].keys()):
367
+ checkpoint["model"][key.replace('module.', '')] = checkpoint["model"].pop(key)
368
+ model.load_state_dict(checkpoint["model"])
369
+
370
+ if opt.resume_all:
371
+ optimizer.load_state_dict(checkpoint['optimizer'])
372
+ lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
373
+ opt.start_epoch = checkpoint['epoch'] + 1
374
+ logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}")
375
+ else:
376
+ logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path")
377
+
378
+ return model, criterion, optimizer, lr_scheduler
main/config_hl.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) THL A29 Limited, a Tencent company. All rights reserved.
2
+
3
+ YOUTUBE_SPLITS = {
4
+ 'dog': {
5
+ 'train': [
6
+ 'BsjTtq337mM', 'eGCD1F74iy8', 'x2Za-t1yHtI', 'iyYiqa0QZXM',
7
+ 'azy9ijU6f9I', 'NNtSZ6cPiwA', 'U9CBalvFfbM', 'AZDkqJaOgJU',
8
+ '-olTgMPAyMI', 'i35F1Ec3Ats', '6bS6-GVLBeM', 'ZGszTEn28v8',
9
+ 'EEb8iSMqwj4', 'p2hYGNkRMCw', '3kbptPDIz4U', 'iLHRqR-M9HQ',
10
+ 'zyooMDuAgCA', 'dOVsQ63N0gg', '7H_qqQvPUzY', 'Z5BEFsaYIS4',
11
+ 'iWO6io44-Fs', 'vVmGisWK0QI', 'L10kN7Btk90', '2yql1mvWbDs',
12
+ 'Iu2nbtr_Uuk', 'NSmOKAauZpM', 'PAhQGoURAro', 'uJ81Us4mBOc',
13
+ '1krGVyfIaOw', 'p9yW6FxsrJ4', 'DLGRJfpGmCQ', '0XTXKe2TOAg',
14
+ 'qpc4OSqeV7I', 'q_PJFuBOk7k', '0Uu53hCnKQ4', '-szRD9kyNug',
15
+ 'rUPxwWmJYpg', 'hseONiKKx_8', 'BLaQcOcDfjo', 'nW5JulWYEc8',
16
+ 'rMvH1SMGwwI', 'l6KlvTJkTgk', 'O8j4U3NjNvs', '8AJTZeEeStk'
17
+ ],
18
+ 'val': [
19
+ 'a2nj7XCo2Rk', '9rP5yF9EC3Y', 'OxSsRZqPfyk', 'bZzP2MieC1c',
20
+ 'PcvdX5OVgfQ', 'p0oxRJD1GUk', 'msjK8nHZHZ0', 'hSRyclcZyGM',
21
+ 'dlH2K9N_jSM', 'OCVXhRG2fEA', 'MkBdHvXPocc', 'yN7h90Y-04g',
22
+ 'PWqLJKZeBC8', '9D_Q8l_ruQk', 'Mp8Pz86J660', '1gjntnYm8NA',
23
+ 'O3XxuutEvoo', 'wf_qlAizlSM', 'fXx44D1sqUw', 'P0MnXh6bnKk',
24
+ 'sTd06idFa0E', 'ppNjl3I3iJs', 'Om5mczkpcVg', 'xZIN_s-qhbU'
25
+ ]
26
+ },
27
+ 'gymnastics': {
28
+ 'train': [
29
+ 'Wfv90YJ2YtA', 'MbD5OIR9yWc', 'fZwCJWkC_Qw', 'AyRI1CioQfY',
30
+ 'xV_5YCdVqSM', '19UO7T32DJI', 'o2gAP2Clg_s', 'ewyfAOrBzjQ',
31
+ 'CMTKpA683Ig', 'aNjphhjTgqs', 'dmJ0Nq4DF2w', '57IQ6EudvGU',
32
+ 'BAlUYtPUsVI', '_UU4XqYVDqE', 'Kq4OhBiQk_E', 'D6nyvx9kEac',
33
+ 'g-m4-zeCisU', '_45vTFtcduE', '9L-Pocc_u70', '0636XaURL-A',
34
+ 'GCabQyaHSMg', 'vUi1Scb35fQ', 'eK-Yuoou_1I', 'kkS7TgNZwJI',
35
+ '2EFkINKg3nA', 'eKvALYDh7RU', 'Hyp3Hpk6dyA', '9rpzf3sgQkw',
36
+ 'kHNAnpewyeo', 'ydQij10qrZM', '41u2V_ZAKto', '6NSWsMKAgEU',
37
+ 'kUs_yUR-C2k', 'bs3ZBcfhvKA'
38
+ ],
39
+ 'val': [
40
+ '2AuigNFEsTM', 'rPsKpHKzUso', 'tzq5cJQ9NQA', 'DyZ0gZ5xmxI',
41
+ 'PEKRfJYYEgU', 'affAIVH9uRA', 'FT7yIi3-tG0', 'T_zWyrVzyvw',
42
+ 'RoiLzMA_ilA', 'nBZiGSccsTg', 'z3cNtOMKK7A', 'EwQ-aMK2sKg',
43
+ 'Rq0BpciuvBM', 's6LNwTThBgs', '-hE9v3izo4c', 'KldEfRhv7H0',
44
+ 'eUyuw2J5FaE', 'E0aRE1_ea8E', 'BU7YlQAOBkM', 'iDJM9j11U-c',
45
+ 'zr5LSPMBpiI', 'NAfBa7lqg2Q', 'eB4Toq9dUWs', 'YPd7RDN5CkE',
46
+ '86YLsw7efDM', 'iQRMMFiYAUw', 'lzEhLAPxZyQ', 'PAjJbT1DRnY'
47
+ ]
48
+ },
49
+ 'parkour': {
50
+ 'train': [
51
+ 'qz1UnnxlWhI', 'MzODICzycHs', '0swXWs9yWA4', 'Nnv22OW_PaI',
52
+ 'LUhZJLY2uKc', 'yZz8z1l3XJU', '3dvjtdMC2ls', 'e27ppPer9XY',
53
+ 'HJNn2WlKFhM', 'j4OxlxnapNI', 'rhABvn7VjSQ', '3PCwXpwYqLs',
54
+ 'LECL1bIpi5w', 'w0ouP79iZWc', 'z6aKQPMJUC0', 'kATlFTwxBVY',
55
+ '3SM6a8eyuVA', 'v-Sfc4COqRQ', '64eu8pwuIUE', '7WKm0XDk3og',
56
+ '2F5Sc0Jgk4g'
57
+ ],
58
+ 'val': [
59
+ 'TFdbCRkVeIA', 'uGLs9atTvNc', 'qlGPuopK3CI', 'ucTkpjZO_o4',
60
+ '4-4BgyGphLQ', '08k4ysX_XJE', '6sMNnWqa_as', 'oT6g0I2Ok9o',
61
+ 'Be4IlnKeBOo', 'yUjJq0kvxcw', 'fLek7GRIxjE'
62
+ ]
63
+ },
64
+ 'skating': {
65
+ 'train': [
66
+ '7owXLUkpoNY', '1OLM0_Jzt5M', 'b1LXb0Sbiy0', '3fGux6-ttlA',
67
+ 'HQvRun80GyA', 'a8M-5nTrll8', 'bA3CxZllhsI', 'AUAsfZtcB4E',
68
+ 'FG57uCJvQLw', 'jXIuv5uFPTI', 'eG-hdYLoS98', '2SdJBl251PU',
69
+ '2PHJqqrGC80', 'EtZkkFhniRw', 'jUiwyguxzIw', 'FL6mXlaF78Q',
70
+ 'BdemklZtYWI', 'ATk_ncI1-BA', '4wiKDfq3X8U', 'BN7GBjVlFTo',
71
+ 'JiMZvMkkbRo', '2DIXYkSnRf4', 'dZ3i-HuhQXM', '7jZydh62m8M'
72
+ ],
73
+ 'val': [
74
+ '2oOe2_Ew6Ao', 'DGcO0QgcXtw', 'ixsKaNplm6o', '7TQbqKWjLcI',
75
+ 'CQZNrEstSag', 'g1WbAIzkw80', '4cyx1VpDjc4', 'BGZaaqFjoRY',
76
+ 'AJ98A2y1dVw', '1n7Afe5AZCM', '8x8ESK5MnR0'
77
+ ]
78
+ },
79
+ 'skiing': {
80
+ 'train': [
81
+ '6Usy87KaF-A', 'DtjKkp_4KDQ', '4Wt7TM2wDxI', 'iKnzSGFwdbc',
82
+ 'nALCc6HPQNs', 'WL4TA--CVcA', 'dFrfsgW1M98', 'x6qmrVojcYc',
83
+ 'pvcmQ9J_BYw', 'S3VEYFAP_pk', 'pU57a3jYMEk', '33TrLdo3ook',
84
+ 'xLhHU8uo2aY', 'fAHBmka6Psc', '9HYzZk5kiJA', 'T0gjqYbeU1g',
85
+ '7o628W-bFy0', 'YKDm_PCa-HM', 'R3DV2zDnNqg', 'NCe9YeXTvHo',
86
+ '5tXxvscmZ-Y', 'thNiPQLbi5w', '1TtJy8cSzqA', 'zDRzOsmwa08',
87
+ 'gCI4gArPjNA', 'uw0i26NHucs', '1giAsZC_ywQ', 'OvgaPTfEnqo',
88
+ 'bFD_p5znoq4', 'uKmqaAvjKgw', '5ivw_sdCTCU', 'iwCSAYGwPq4',
89
+ 'HmmOPntPlRA', 'FHCEyiM-NoY', 'EUSFMmoE_jI', 'igvSxtdsT8w',
90
+ 'zEgMYFiEaX4', '0K2FKccDp9A', 'tdyz6h4ZtYs', 'PO7GEbi2z3c',
91
+ 'mmiu7rRmSAU', 'qL6Kic-CdTo', '0fNCsOY1WGk', 'V3J26hr1ZSE',
92
+ 'GS-qBunN3B4', 'ZLNvg8025Nw', 'puAxGH6aWMY', 'h-SlvHubhs8',
93
+ 'AdovZ4OAS8I', 'UDvA1XMa1m4', 'qdo3d7mR_9s', 'qAinbyORWIw',
94
+ 'v1JpJueAElY', 'TjH29fdjcqI', 'f76B1uucoyo', 'DNPPDcOd5eQ',
95
+ '-GX95udKKm8', 'YRO_RQ3aBgg', '1ptV2E7lm9U', 'qa7dtf1Qcew',
96
+ '_UJTkqYNrpA', 'md14DNKq2_o', 'tpewrb9dDyo', 'yGoWYi_dHLY',
97
+ 'DZ3NRjDHwy8', 'aMFcEuJUqpk', '6fT9KLuE7no', 'lPdQMMAuOZo'
98
+ ],
99
+ 'val': [
100
+ 'SSlv7qJK5zA', '_BYqZjuKpKA', 'ZueaKXReGjU', 'mGST8ZekCZc',
101
+ 'JJSu7Lh9rvs', 'IyoD3G5igY0', 'MXyv-Ut9HRg', 'Z8X9WIojH1U',
102
+ 'vT33-8KUb2Q', 'HW6_sPym938', '9wtXO2lF6hM', 'mRdthCqe6Nk',
103
+ 'RGxiOb9hlS0', 'ruySf5zL7Kw', 'I7wFmP6P7p0', '0AHkDElk3ws',
104
+ 'zqXd4EgUFhE', '91lDbBHUx0w', 'iaHbK6ogafc', 'jRbst8kjWW8',
105
+ 'drHPy6wSZGs', '5VaY6LgIqDs', 'bXq9rRSbI3c', 'hjZLa2DTuqs',
106
+ 'Ka2qcp3jmWo', 'ZnA4-ggkFu8', 'iXdt4v42mbs', '8aWN-0NZErI',
107
+ '09v0HNf81J0', 'YJCR2q-WRhQ', 'RjagI4pAUpw', '_10CbYdTG5M',
108
+ 'lhgmIgzBQxs', '2pstGBM4p0w', 'b53-VPsWom4', 'x-G4r153n6o',
109
+ 'qBbqK5qlVSM', 'XamrS9XyHuQ', 'u_n7jMS1vlw', 'AO6p0jlOd6U',
110
+ 'm-W-lcTkBQ0', 'bMuyPVIlXW8', 'kAAvTAKkIy4', 'U6vnbCurZQA',
111
+ 'dHE8q7sZ70U', 'w7fzLVRPSUc', 'FLYkD7zHuHQ', 'nhOhI24P7dM',
112
+ 'n5q2KhfoiWw', '7Hcyse0h9HE', '6_BPy_VaPSY'
113
+ ]
114
+ },
115
+ 'surfing': {
116
+ 'train': [
117
+ 'Ai9FwQGn5ds', 'hBl0Sm3_auw', 'LMxMeg407Vg', 'D3fk8doVui4',
118
+ 'Y9pxmLg6ti8', 'p_JsivYdbgQ', 'UokX-hcXQeo', 'VYe5QfM5ecE',
119
+ 'I48VJ92ouTQ', 'Tn-ebtUnq6E', 'eWae-nWocPU', '-Yamat_0tbw',
120
+ 'c2Fy-rdXJy4', 'xQ4NAp4vWbI', 'g9kXCIjIjoE', 'A96Jx6gv6_4',
121
+ 'e427qElqqN0', 'tTcA5hiViPo', 'wMdXzj_3aA0', 'fqNzMz1n6uA',
122
+ 'jKVOA7RFCUo', 'TJBJrk9iPPA', '_C8EjMxrS2s', 'yj7abHfZTQQ',
123
+ 'NDcqgpsyWaU', 'UJjwoivaGNo', 'GZ_XS8EnnWo', 'kJUBIcBjUZ0',
124
+ 'lWoLyR7lDAU', 'FilbyF_PGjI', 'fapRkcOe4vE', 't05r50PQqww',
125
+ 'QgStLppe610', '2TY8Q2WXUyk', '9y_ED3DyNhE', 'CGwtinVGkVU',
126
+ 'nOuRhrAMaIw', 'UN4TwjDajtQ', '-FHmVZWWgcE', 'ksx0_BfpsLg',
127
+ 'agOBPDsQrTM', 'XqggBwFOmFU', 'orNzj1J8i-4', '6ZbTCHwt1gk',
128
+ '0un3wh_pQAc', '4u6OURBLZDs', 'us0agAKuvEM', 'mVQYl7Q-TQs',
129
+ 'cB2SdlGHLMQ', 'WK5t4To0zlA', 'NNEuH_juUHI', 'KTU7xfVOat0',
130
+ 'Y1nhbNaY1ZY', 'YlXJnZe575s', 'SH7Ns0ANzJU', '3TbZfeokCkE'
131
+ ],
132
+ 'val': [
133
+ 'o0on6yIXJQE', '4RsZz_8d8Ro', 'p8VUjcZyK70', '0P2PZXUa0Bg',
134
+ 'p2eU5z647Mw', 'mSVxaAJcNJQ', 'bcmXVyFbsRg', 'Eiq8GHi4kEo',
135
+ 'H5FEdJYokO4', 'Mkyp0z_Cgig', 'NB5Ez5kJfMU', 'Xa0y6b6Vm6U',
136
+ 'gVcCGUtpA90', '0-fstXuo_Pw', '-d72e4v9skA', 'lbp6_wCXqvw',
137
+ '9GpZHq1n8ps', 'CefGXyYu_zU', 'SI2JbS48Upg', 'hdklRTNrq0I',
138
+ 'J-P-t6g19SM', 'K0f_DpVOjfA', 'lw_1fEY9QTo', 'uUuYnKLETLw',
139
+ 'HwKv3Xc5MAE', 'wvQ0h5Nwsxc', 'l8ME6z_EWKE', 's9dTu2fcbNg',
140
+ 'GS09SevPYT4', 'YbwdDCzVczU', 'jaCOI_VwIjc', '3Y1Jp1_fFLQ',
141
+ '82OzgxT2tH8', 'IjQhHPlTfdE', 'KzQcJrT91jU', 't05AD0c08zE',
142
+ 'rGxWxX6nYO4', 'QGp0kRzKiAc', 'pK9gDWoOyko', 'Srjd4pe6vck',
143
+ 'twGcxuhCXoU', 'AshLUHPEb8M', '8En3M5CUc2E', '8sTJfTUk1d0',
144
+ 'o-bubyWTw60', 'NctbssxGCtU', 'L09Qo1ql0nM'
145
+ ]
146
+ }
147
+ }
148
+
149
+ TVSUM_SPLITS = {
150
+ 'BK': {
151
+ 'train': ['WxtbjNsCQ8A', 'EE-bNr36nyA', 'oDXZc0tZe04', 'uGu_10sucQo'],
152
+ 'val': ['Se3oxnaPsz0']
153
+ },
154
+ 'BT': {
155
+ 'train': ['eQu1rNs0an0', 'qqR6AEXwxoQ', 'EYqVtI9YWJA', 'iVt07TCkFM0'],
156
+ 'val': ['JgHubY5Vw3Y']
157
+ },
158
+ 'DS': {
159
+ 'train': ['kLxoNp-UchI', 'NyBmCxDoHJU', 'jcoYJXDG9sw', '-esJrBWj2d8'],
160
+ 'val': ['E11zDS9XGzg']
161
+ },
162
+ 'FM': {
163
+ 'train': ['_xMr-HKMfVA', 'byxOvuiIJV0', 'VuWGsYPqAX8', 'xmEERLqJ2kU'],
164
+ 'val': ['JKpqYvAdIsw']
165
+ },
166
+ 'GA': {
167
+ 'train': ['xxdtq8mxegs', 'i3wAGJaaktw', '0tmA_C6XwfM', '3eYKfiOEJNs'],
168
+ 'val': ['Bhxk-O1Y7Ho']
169
+ },
170
+ 'MS': {
171
+ 'train': ['Hl-__g2gn_A', 'WG0MBPpPC6I', 'LRw_obCPUt0', '37rzWOQsNIw'],
172
+ 'val': ['Yi4Ij2NM7U4']
173
+ },
174
+ 'PK': {
175
+ 'train': ['GsAD1KT1xo8', 'XkqCExn6_Us', 'b626MiF1ew4', 'PJrm840pAUI'],
176
+ 'val': ['cjibtmSLxQ4']
177
+ },
178
+ 'PR': {
179
+ 'train': ['RBCABdttQmI', 'z_6gVvQb2d0', '4wU_LUjG5Ic', '91IHQYk1IQM'],
180
+ 'val': ['fWutDQy1nnY']
181
+ },
182
+ 'VT': {
183
+ 'train': ['gzDbaEs1Rlg', 'XzYM3PfTM4w', '98MoyGZKHXc', 'AwmHb44_ouw'],
184
+ 'val': ['J0nA4VgnoCo']
185
+ },
186
+ 'VU': {
187
+ 'train': ['akI8YFjEmUw', 'HT5vyqe0Xaw', 'vdmoEJ5YbrQ', 'xwqBXPGE9pQ'],
188
+ 'val': ['sTEELN-vY30']
189
+ }
190
+ }
main/config_qfvs.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ // "max_segment_num": 20,
3
+ // "max_frame_num": 200,
4
+
5
+ // "train_videos": null,
6
+ // "test_videos": null,
7
+ // "top_percent": 0.02,
8
+
9
+ // "vid_feature": "fps1",
10
+ // "txt_feature": "query",
11
+ // "txt_max_len": 5,
12
+
13
+ // "factor": null
14
+ }
main/dataset.py ADDED
@@ -0,0 +1,1261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import h5py
4
+ import nncore
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import random
10
+ import logging
11
+ from os.path import join, exists
12
+ from nncore.dataset import DATASETS
13
+ from nncore.parallel import DataContainer
14
+ from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
15
+ from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
16
+ from utils.tensor_utils import pad_sequences_1d
17
+ from utils.span_utils import span_xx_to_cxw
18
+ from random import shuffle
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ class DatasetVLP(Dataset):
23
+ Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
24
+ """One line in data loaded from data_path."
25
+ {
26
+ "qid": 7803,
27
+ "query": "Man in gray top walks from outside to inside.",
28
+ "duration": 150,
29
+ "vid": "RoripwjYFp8_360.0_510.0",
30
+ "relevant_clip_ids": [13, 14, 15, 16, 17],
31
+ "relevant_windows": [[26, 36]]
32
+ }
33
+ """
34
+ def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim,
35
+ q_feat_type="last_hidden_state",
36
+ max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
37
+ normalize_v=True, normalize_t=True, load_labels=True,
38
+ clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
39
+ use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1):
40
+ self.dset_name = dset_name
41
+ self.data_path = data_path
42
+ self.data_ratio = data_ratio
43
+ self.v_feat_dirs = v_feat_dirs \
44
+ if isinstance(v_feat_dirs, list) else [v_feat_dirs]
45
+ self.q_feat_dir = q_feat_dir
46
+ self.q_feat_type = q_feat_type
47
+ self.v_feat_dim = v_feat_dim
48
+ self.q_feat_dim = q_feat_dim
49
+ self.max_q_l = max_q_l
50
+ self.max_v_l = max_v_l
51
+ self.ctx_mode = ctx_mode
52
+ self.use_tef = "tef" in ctx_mode
53
+ self.use_video = "video" in ctx_mode
54
+ self.normalize_t = normalize_t
55
+ self.normalize_v = normalize_v
56
+ self.load_labels = load_labels
57
+ self.clip_len = clip_len
58
+ self.fix_len = fix_len
59
+ self.max_windows = max_windows # maximum number of windows to use as labels
60
+ self.span_loss_type = span_loss_type
61
+ self.txt_drop_ratio = txt_drop_ratio
62
+ self.use_cache = use_cache
63
+ self.add_easy_negative = add_easy_negative
64
+ self.easy_negative_only = easy_negative_only
65
+
66
+ self.vlp_mapping = {
67
+ # 'data/qvhighlights/metadata/qvhighlights_asr.jsonl': {
68
+ # 'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '_asr', 'type': 'interval',
69
+ # },
70
+ # 'data/ego4d/metadata/point_train_1m.jsonl': {
71
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
72
+ # },
73
+ # 'data/ego4d/metadata/point_train_1m_0.1p.jsonl': {
74
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
75
+ # },
76
+ # 'data/ego4d/metadata/point_train_1m_0.2p.jsonl': {
77
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
78
+ # },
79
+ # 'data/ego4d/metadata/point_train_1m_0.5p.jsonl': {
80
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
81
+ # },
82
+ # 'data/ego4d/metadata/point_train_1m_0.75p.jsonl': {
83
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
84
+ # },
85
+ # 'data/ego4d/metadata/point_train_2m.jsonl': {
86
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
87
+ # },
88
+ # 'data/ego4d/metadata/point_train_1m_egoclip.jsonl': {
89
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
90
+ # },
91
+ # 'data/hacs/metadata/hacs_train_cs.jsonl': {
92
+ # 'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '_cs', 'type': 'curve',
93
+ # },
94
+ # 'data/hacs/metadata/hacs_train.jsonl': {
95
+ # 'dset_name': 'hacs', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve',
96
+ # },
97
+ # 'data/videocc/metadata/train_300k.jsonl': {
98
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
99
+ # },
100
+ # 'data/videocc/metadata/train_600k.jsonl': {
101
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
102
+ # },
103
+ # 'data/videocc/metadata/train_600k_0.1p.jsonl': {
104
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
105
+ # },
106
+ # 'data/videocc/metadata/train_600k_0.2p.jsonl': {
107
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
108
+ # },
109
+ # 'data/videocc/metadata/train_600k_0.5p.jsonl': {
110
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
111
+ # },
112
+ # 'data/videocc/metadata/train_600k_0.75p.jsonl': {
113
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
114
+ # },
115
+ # 'data/videocc/metadata/train_900k.jsonl': {
116
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
117
+ # },
118
+ # 'data/ego4d/metadata/concept_train_top10_window.jsonl': {
119
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
120
+ # },
121
+ # 'data/ego4d/metadata/concept_train_top5_window.jsonl': {
122
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
123
+ # },
124
+ # 'data/ego4d/metadata/concept_train_top5_window_0.1p.jsonl': {
125
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
126
+ # },
127
+ # 'data/ego4d/metadata/concept_train_top5_window_0.2p.jsonl': {
128
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
129
+ # },
130
+ # 'data/ego4d/metadata/concept_train_top5_window_0.5p.jsonl': {
131
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
132
+ # },
133
+ # 'data/ego4d/metadata/concept_train_top5_window_0.75p.jsonl': {
134
+ # 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_concept', 'type': 'curve',
135
+ # },
136
+ # 'data/videocc/metadata/concept_train_top10_window.jsonl': {
137
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
138
+ # },
139
+ # 'data/videocc/metadata/concept_train_top5_window.jsonl': {
140
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
141
+ # },
142
+ # 'data/videocc/metadata/concept_train_top5_window_0.1p.jsonl': {
143
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
144
+ # },
145
+ # 'data/videocc/metadata/concept_train_top5_window_0.2p.jsonl': {
146
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
147
+ # },
148
+ # 'data/videocc/metadata/concept_train_top5_window_0.5p.jsonl': {
149
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
150
+ # },
151
+ # 'data/videocc/metadata/concept_train_top5_window_0.75p.jsonl': {
152
+ # 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
153
+ # },
154
+ #
155
+ # pre-training
156
+ 'data/ego4d/metadata/point_egoclip_wo_val.jsonl': {
157
+ 'dset_name': 'ego4d', 'v_feat_suffix': '_point', 'q_feat_suffix': '_point', 'type': 'point',
158
+ },
159
+ 'data/videocc/metadata/interval_900k.jsonl': {
160
+ 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
161
+ },
162
+ 'data/videocc/metadata/curve_5_window.jsonl': {
163
+ 'dset_name': 'videocc', 'v_feat_suffix': '', 'q_feat_suffix': '_concept', 'type': 'curve',
164
+ },
165
+ # downstream
166
+ 'data/qvhighlights/metadata/qvhighlights_train.jsonl': {
167
+ 'dset_name': 'qvhighlights', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'curve',
168
+ },
169
+ 'data/charades/metadata/charades_train.jsonl': {
170
+ 'dset_name': 'charades', 'v_feat_suffix': '_2', 'q_feat_suffix': '', 'type': 'interval',
171
+ },
172
+ 'data/ego4d/metadata/nlq_train.jsonl': {
173
+ 'dset_name': 'ego4d', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
174
+ },
175
+ 'data/tacos/metadata/train.jsonl': {
176
+ 'dset_name': 'tacos', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
177
+ },
178
+ 'data/anet/metadata/train.jsonl': {
179
+ 'dset_name': 'anet', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
180
+ },
181
+ 'data/didemo/metadata/train.jsonl': {
182
+ 'dset_name': 'didemo', 'v_feat_suffix': '', 'q_feat_suffix': '', 'type': 'interval',
183
+ },
184
+ }
185
+
186
+ if "val" in data_path or "test" in data_path:
187
+ assert txt_drop_ratio == 0
188
+
189
+ # checks
190
+ assert q_feat_type in self.Q_FEAT_TYPES
191
+
192
+ # data
193
+ self.data = self.load_data()
194
+
195
+ self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs]
196
+ t_feat_type = q_feat_dir.split('/')[-1]
197
+
198
+ if self.use_cache > 0:
199
+ print('Loading the off-line features...')
200
+ dset_dir = os.path.join('data', self.dset_name)
201
+ vid_keys = [meta['vid'] for meta in self.data]
202
+ qid_keys = [meta['qid'] for meta in self.data]
203
+
204
+ self.vid_cache = {}
205
+ for v_feat_type in self.v_feat_types:
206
+ assert 'vid' in v_feat_type
207
+ with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f:
208
+ self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)}
209
+
210
+ assert 'txt' in t_feat_type
211
+ self.txt_cache = {}
212
+ with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f:
213
+ for key in tqdm(qid_keys):
214
+ try:
215
+ self.txt_cache[key] = f[str(key)][:]
216
+ except:
217
+ logger.info(f"text {key} is not in the cache.")
218
+
219
+ def load_data(self):
220
+ # datalist = load_jsonl(self.data_path[0])
221
+ datalist = []
222
+ for dset_path in self.data_path:
223
+ dset_info = self.vlp_mapping[dset_path]
224
+ dset_list = load_jsonl(dset_path)
225
+ for x in dset_list: x.update(dset_info)
226
+ datalist += dset_list
227
+ n_examples = int(len(datalist))
228
+ if self.data_ratio != 1:
229
+ n_examples = int(len(datalist) * self.data_ratio)
230
+ shuffle(datalist)
231
+ datalist = datalist[:n_examples]
232
+ logger.info("Using {}% of the data: {} examples"
233
+ .format(self.data_ratio * 100, n_examples))
234
+ return datalist
235
+
236
+ def __len__(self):
237
+ return len(self.data)
238
+
239
+ def __getitem__(self, index):
240
+ meta = self.data[index]
241
+
242
+ model_inputs = dict()
243
+ model_inputs["query_feat"] = self._get_query_feat_by_qid(meta) # (Dq, ) or (Lq, Dq)
244
+
245
+ if self.use_video:
246
+ model_inputs["video_feat"] = self._get_video_feat_by_vid(meta) # (Lv, Dv)
247
+ ctx_l = len(model_inputs["video_feat"])
248
+ else:
249
+ ctx_l = self.max_v_l
250
+
251
+ if meta['dset_name'] in ['hacs', 'ego4d', 'activitynet']:
252
+ for i, window_i in enumerate(meta["relevant_windows"]):
253
+ if window_i[1] - window_i[0] < self.clip_len:
254
+ center = (window_i[1] + window_i[0]) / 2
255
+ window_i[0] = max(0, center - 0.5 * self.clip_len)
256
+ window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len)
257
+ window_i[1] = max(self.clip_len, window_i[1])
258
+
259
+ model_inputs["timestamp"] = ( (torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
260
+
261
+ if 'test' in self.data_path and 'qvhighlights' in self.dset_name:
262
+ meta["relevant_windows"] = [[0, 150]]
263
+ relevant_windows = torch.Tensor(meta["relevant_windows"])
264
+
265
+ # assign the nearest window for each timestamp i.e., qvhighlights.
266
+ num_vid_seq = model_inputs["timestamp"].shape[0]
267
+ num_windows = relevant_windows.shape[0]
268
+
269
+ relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len)
270
+ relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1)
271
+ model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1)
272
+
273
+ if meta['qid'] is not None:
274
+ nn_window_ts = torch.zeros_like(model_inputs["timestamp"])
275
+ diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0]
276
+ diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1]
277
+ assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0))
278
+ if min(assign_idx[0].shape) == 0: # not assigned, happened in activitynet.
279
+ nn_window_ts = relevant_windows_ts.squeeze(1)
280
+ else:
281
+ nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]]
282
+
283
+ model_inputs["span_labels_nn"] = nn_window_ts
284
+ model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:,0] >= nn_window_ts[:,0]) & (model_inputs["timestamp"][:,1] <= nn_window_ts[:,1])
285
+
286
+ # for activitynet.
287
+ if model_inputs["timestamp_window"].sum() < 1:
288
+ idx = int(meta['relevant_windows'][0][0] / self.clip_len)
289
+ idx = max(0, min(idx, ctx_l-1))
290
+ model_inputs["timestamp_window"][idx] = 1
291
+
292
+ if self.use_tef:
293
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
294
+ tef_ed = tef_st + 1.0 / ctx_l
295
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
296
+ if self.use_video:
297
+ model_inputs["video_feat"] = torch.cat(
298
+ [model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2)
299
+ else:
300
+ model_inputs["video_feat"] = tef
301
+
302
+ if self.load_labels:
303
+ model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2)
304
+ if 'saliency_scores' in meta.keys():
305
+ # this is for highlight-only task
306
+ model_inputs["saliency_scores"] = torch.zeros(ctx_l).double()
307
+ limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None
308
+ model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1))
309
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
310
+ self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
311
+ # pdb.set_trace()
312
+ else:
313
+ model_inputs["saliency_scores"] = model_inputs["timestamp_window"]
314
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
315
+ self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l) # only one gt
316
+ model_inputs["saliency_pos_labels"] = [ random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist()) ]
317
+
318
+ if 'type' in meta.keys():
319
+ if meta['type'] == 'point':
320
+ model_inputs['weight_ablation'] = torch.tensor([0, 0, 1, 0, 0])
321
+ if meta['type'] == 'interval':
322
+ model_inputs['weight_ablation'] = torch.tensor([1, 1, 0, 0, 0])
323
+ if meta['type'] == 'curve':
324
+ model_inputs['weight_ablation'] = torch.tensor([0, 0, 0, 1, 1])
325
+
326
+ return dict(meta=meta, model_inputs=model_inputs)
327
+
328
+ def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1):
329
+ gt_st = int(gt_window[0] / self.clip_len)
330
+ gt_st = min(gt_st, ctx_l-1)
331
+ gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
332
+ if gt_st > gt_ed:
333
+ # gt_st = gt_ed
334
+ gt_ed = gt_st
335
+
336
+ if gt_st != gt_ed:
337
+ pos_clip_indices = random.sample(range(gt_st, gt_ed+1), k=max_n)
338
+ else:
339
+ pos_clip_indices = [gt_st] * max_n #[gt_st, gt_st]
340
+
341
+ neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l))
342
+ # neg_clip_indices = random.sample(neg_pool, k=max_n)
343
+
344
+ try:
345
+ neg_clip_indices = random.sample(neg_pool, k=max_n)
346
+ except:
347
+ neg_clip_indices = pos_clip_indices
348
+
349
+ return pos_clip_indices, neg_clip_indices
350
+
351
+ def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1):
352
+ """Sum the scores from the three annotations, then take the two clips with the
353
+ maximum scores as positive, and two with the minimum scores as negative.
354
+ Args:
355
+ rel_clip_ids: list(int), list of relevant clip ids
356
+ scores: list([anno1_score, anno2_score, anno3_score]),
357
+ ctx_l: int
358
+ max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
359
+ add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids.
360
+ """
361
+ # indices inside rel_clip_ids
362
+ scores = np.array(scores) # (#rel_clips, 3)
363
+ agg_scores = np.sum(scores, 1) # (#rel_clips, )
364
+ sort_indices = np.argsort(agg_scores) # increasing
365
+
366
+ # indices in the whole video
367
+ # the min(_, ctx_l-1) here is incorrect, but should not cause
368
+ # much troubles since this should be rarely used.
369
+ hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
370
+ hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
371
+
372
+ if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]:
373
+ hard_neg_clip_indices = hard_pos_clip_indices
374
+
375
+ easy_pos_clip_indices = []
376
+ easy_neg_clip_indices = []
377
+ # pdb.set_trace()
378
+ if self.add_easy_negative > 0:
379
+ easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
380
+ if len(easy_neg_pool) >= max_n:
381
+ easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
382
+ easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
383
+ else: # copy the hard ones
384
+ easy_pos_clip_indices = hard_pos_clip_indices
385
+ easy_neg_clip_indices = hard_neg_clip_indices
386
+
387
+ if self.easy_negative_only > 0:
388
+ return easy_pos_clip_indices, easy_neg_clip_indices
389
+
390
+ pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
391
+ neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
392
+
393
+ return pos_clip_indices, neg_clip_indices
394
+
395
+ def get_span_labels(self, windows, ctx_l):
396
+ """
397
+ windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
398
+ Note a maximum of `self.max_windows` windows are used.
399
+ returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
400
+ """
401
+ if len(windows) > self.max_windows:
402
+ random.shuffle(windows)
403
+ windows = windows[:self.max_windows]
404
+ if self.span_loss_type == "l1":
405
+ windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx
406
+ windows = span_xx_to_cxw(windows) # normalized windows in cxw
407
+ elif self.span_loss_type == "ce":
408
+ windows = torch.Tensor([
409
+ [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
410
+ for w in windows]).long() # inclusive
411
+ else:
412
+ raise NotImplementedError
413
+ return windows
414
+
415
+ def _get_query_feat_by_qid(self, meta):
416
+ qid = meta['qid']
417
+ dset_name = meta['dset_name']
418
+ q_feat_suffix = meta['q_feat_suffix']
419
+ q_feat_dir = self.q_feat_dir + q_feat_suffix
420
+
421
+ if self.use_cache > 0:
422
+ try:
423
+ q_feat = self.txt_cache[qid]
424
+ except:
425
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
426
+ return torch.from_numpy(q_feat)
427
+
428
+ q_feat_path = os.path.join('data', dset_name, q_feat_dir, f"{qid}.npz")
429
+ try:
430
+ q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
431
+ except:
432
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
433
+ logger.info(f"Something wrong when loading the query feature {q_feat_path}.")
434
+
435
+ if self.q_feat_type == "last_hidden_state":
436
+ # q_feat = q_feat[:self.max_q_l]
437
+ q_feat = q_feat
438
+ if self.normalize_t:
439
+ q_feat = l2_normalize_np_array(q_feat)
440
+ if self.txt_drop_ratio > 0:
441
+ q_feat = self.random_drop_rows(q_feat)
442
+ return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
443
+
444
+ def random_drop_rows(self, embeddings):
445
+ """randomly mask num_drop rows in embeddings to be zero.
446
+ Args:
447
+ embeddings: np.ndarray (L, D)
448
+ """
449
+ num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
450
+ if num_drop_rows > 0:
451
+ row_indices = np.random.choice(
452
+ len(embeddings), size=num_drop_rows, replace=False)
453
+ embeddings[row_indices] = 0
454
+ return embeddings
455
+
456
+ def _get_video_feat_by_vid(self, meta):
457
+ dset_name = meta['dset_name']
458
+ v_feat_suffix = meta['v_feat_suffix']
459
+ vid = meta['vid']
460
+
461
+ v_feat_list = []
462
+ for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
463
+ v_feat_dir = _feat_dir + v_feat_suffix
464
+ if self.use_cache > 0:
465
+ _feat = self.vid_cache[feat_type][vid]
466
+ else:
467
+ _feat_path = os.path.join('data', dset_name, v_feat_dir, f"{vid}.npz")
468
+ _feat = np.load(_feat_path)["features"].astype(np.float32)
469
+ if self.normalize_v:
470
+ _feat = l2_normalize_np_array(_feat)
471
+ v_feat_list.append(_feat)
472
+ # some features are slightly longer than the others
473
+ min_len = min([len(e) for e in v_feat_list])
474
+ v_feat_list = [e[:min_len] for e in v_feat_list]
475
+ v_feat = np.concatenate(v_feat_list, axis=1)
476
+ return torch.from_numpy(v_feat) # (Lv, D)
477
+
478
+ class DatasetMR(Dataset):
479
+ Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"]
480
+ """One line in data loaded from data_path."
481
+ {
482
+ "qid": 7803,
483
+ "query": "Man in gray top walks from outside to inside.",
484
+ "duration": 150,
485
+ "vid": "RoripwjYFp8_360.0_510.0",
486
+ "relevant_clip_ids": [13, 14, 15, 16, 17],
487
+ "relevant_windows": [[26, 36]]
488
+ }
489
+ """
490
+ def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, v_feat_dim, q_feat_dim,
491
+ q_feat_type="last_hidden_state",
492
+ max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video",
493
+ normalize_v=True, normalize_t=True, load_labels=True,
494
+ clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0,
495
+ use_cache=-1, fix_len=-1, add_easy_negative=1, easy_negative_only=-1):
496
+ self.dset_name = dset_name
497
+ self.data_path = data_path[0] if isinstance(data_path, list) else data_path
498
+ self.data_ratio = data_ratio
499
+ self.v_feat_dirs = v_feat_dirs \
500
+ if isinstance(v_feat_dirs, list) else [v_feat_dirs]
501
+ self.q_feat_dir = q_feat_dir
502
+ self.q_feat_type = q_feat_type
503
+ self.v_feat_dim = v_feat_dim
504
+ self.q_feat_dim = q_feat_dim
505
+ self.max_q_l = max_q_l
506
+ self.max_v_l = max_v_l
507
+ self.ctx_mode = ctx_mode
508
+ self.use_tef = "tef" in ctx_mode
509
+ self.use_video = "video" in ctx_mode
510
+ self.normalize_t = normalize_t
511
+ self.normalize_v = normalize_v
512
+ self.load_labels = load_labels
513
+ self.clip_len = clip_len
514
+ self.fix_len = fix_len
515
+ self.max_windows = max_windows # maximum number of windows to use as labels
516
+ self.span_loss_type = span_loss_type
517
+ self.txt_drop_ratio = txt_drop_ratio
518
+ self.use_cache = use_cache
519
+ self.add_easy_negative = add_easy_negative
520
+ self.easy_negative_only = easy_negative_only
521
+
522
+ if "val" in data_path or "test" in data_path:
523
+ assert txt_drop_ratio == 0
524
+
525
+ # checks
526
+ assert q_feat_type in self.Q_FEAT_TYPES
527
+
528
+ # data
529
+ self.data = self.load_data()
530
+
531
+ self.v_feat_types = [feat_dir.split('/')[-1] for feat_dir in self.v_feat_dirs]
532
+ t_feat_type = q_feat_dir.split('/')[-1]
533
+
534
+ if self.use_cache > 0:
535
+ print('Loading the off-line features...')
536
+ dset_dir = os.path.join('data', self.dset_name)
537
+ vid_keys = [meta['vid'] for meta in self.data]
538
+ qid_keys = [meta['qid'] for meta in self.data]
539
+
540
+ self.vid_cache = {}
541
+ for v_feat_type in self.v_feat_types:
542
+ assert 'vid' in v_feat_type
543
+ with h5py.File(os.path.join(dset_dir, 'h5py', v_feat_type + '.hdf5'), 'r') as f:
544
+ self.vid_cache[v_feat_type] = {key: f[str(key)][:] for key in tqdm(vid_keys)}
545
+
546
+ assert 'txt' in t_feat_type
547
+ self.txt_cache = {}
548
+ with h5py.File(os.path.join(dset_dir, 'h5py', t_feat_type + '.hdf5'), 'r') as f:
549
+ for key in tqdm(qid_keys):
550
+ try:
551
+ self.txt_cache[key] = f[str(key)][:]
552
+ except:
553
+ logger.info(f"text {key} is not in the cache.")
554
+
555
+ def load_data(self):
556
+ datalist = load_jsonl(self.data_path)
557
+ if self.data_ratio != 1:
558
+ n_examples = int(len(datalist) * self.data_ratio)
559
+ datalist = datalist[:n_examples]
560
+ logger.info("Using {}% of the data: {} examples"
561
+ .format(self.data_ratio * 100, n_examples))
562
+ return datalist
563
+
564
+ def __len__(self):
565
+ return len(self.data)
566
+
567
+ def __getitem__(self, index):
568
+ meta = self.data[index]
569
+
570
+ model_inputs = dict()
571
+ model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"]) # (Dq, ) or (Lq, Dq)
572
+
573
+ if self.use_video:
574
+ model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"]) # (Lv, Dv)
575
+ ctx_l = len(model_inputs["video_feat"])
576
+ else:
577
+ ctx_l = self.max_v_l
578
+
579
+ if self.dset_name in ['hacs', 'ego4d', 'videocc', 'activitynet']:
580
+ for i, window_i in enumerate(meta["relevant_windows"]):
581
+ if window_i[1] - window_i[0] < self.clip_len:
582
+ center = (window_i[1] + window_i[0]) / 2
583
+ window_i[0] = max(0, center - 0.5 * self.clip_len)
584
+ window_i[1] = min(float(meta['duration']), center + 0.5 * self.clip_len)
585
+ window_i[1] = max(self.clip_len, window_i[1])
586
+
587
+ model_inputs["timestamp"] = ( (torch.arange(0, ctx_l) + self.clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
588
+
589
+ if 'test' in self.data_path and 'qvhighlights' in self.dset_name:
590
+ meta["relevant_windows"] = [[0, 150]]
591
+ relevant_windows = torch.Tensor(meta["relevant_windows"])
592
+
593
+ # assign the nearest window for each timestamp i.e., qvhighlights.
594
+ num_vid_seq = model_inputs["timestamp"].shape[0]
595
+ num_windows = relevant_windows.shape[0]
596
+
597
+ relevant_windows_ts = relevant_windows / (ctx_l * self.clip_len)
598
+ relevant_windows_ts = relevant_windows_ts.unsqueeze(0).repeat(num_vid_seq, 1, 1)
599
+ model_inputs_ts = model_inputs["timestamp"].unsqueeze(1).repeat(1, num_windows, 1)
600
+
601
+ if meta['qid'] is not None:
602
+ nn_window_ts = torch.zeros_like(model_inputs["timestamp"])
603
+ diff_left = model_inputs_ts[..., 0] - relevant_windows_ts[..., 0]
604
+ diff_right = relevant_windows_ts[..., 1] - model_inputs_ts[..., 1]
605
+ assign_idx = torch.where((diff_left >= 0) * (diff_right >= 0))
606
+ if min(assign_idx[0].shape) == 0: # not assigned, happened in activitynet.
607
+ nn_window_ts = relevant_windows_ts.squeeze(1)
608
+ else:
609
+ nn_window_ts[assign_idx[0]] = relevant_windows_ts[assign_idx[0], assign_idx[1]]
610
+
611
+ model_inputs["span_labels_nn"] = nn_window_ts
612
+ model_inputs["timestamp_window"] = 1 * (model_inputs["timestamp"][:,0] >= nn_window_ts[:,0]) & (model_inputs["timestamp"][:,1] <= nn_window_ts[:,1])
613
+
614
+ # for activitynet.
615
+ if model_inputs["timestamp_window"].sum() < 1:
616
+ idx = int(meta['relevant_windows'][0][0] / self.clip_len)
617
+ idx = max(0, min(idx, ctx_l-1))
618
+ model_inputs["timestamp_window"][idx] = 1
619
+
620
+ if self.use_tef:
621
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
622
+ tef_ed = tef_st + 1.0 / ctx_l
623
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
624
+ if self.use_video:
625
+ model_inputs["video_feat"] = torch.cat(
626
+ [model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2)
627
+ else:
628
+ model_inputs["video_feat"] = tef
629
+
630
+ if self.load_labels:
631
+ model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2)
632
+ if 'saliency_scores' in meta.keys():
633
+ model_inputs["saliency_scores"] = torch.zeros(ctx_l).double()
634
+ limit = meta["relevant_clip_ids"].index(ctx_l) if (np.array(meta["relevant_clip_ids"]) >= ctx_l).any() else None
635
+ model_inputs["saliency_scores"][meta["relevant_clip_ids"][:limit]] = torch.tensor(np.mean(np.array(meta["saliency_scores"][:limit]), -1))
636
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
637
+ self.get_saliency_labels(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l)
638
+ else:
639
+ model_inputs["saliency_scores"] = model_inputs["timestamp_window"]
640
+ model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"] = \
641
+ self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], ctx_l) # only one gt
642
+ model_inputs["saliency_pos_labels"] = [ random.choice(torch.where(model_inputs['saliency_scores'])[0].tolist()) ]
643
+
644
+ return dict(meta=meta, model_inputs=model_inputs)
645
+
646
+ def get_saliency_labels_sub_as_query(self, gt_window, ctx_l, max_n=1):
647
+ gt_st = int(gt_window[0] / self.clip_len)
648
+ gt_st = min(gt_st, ctx_l-1)
649
+ gt_ed = max(0, min(int(gt_window[1] / self.clip_len), ctx_l) - 1)
650
+ if gt_st > gt_ed:
651
+ gt_ed = gt_st
652
+
653
+ if gt_st != gt_ed:
654
+ pos_clip_indices = random.sample(range(gt_st, gt_ed+1), k=max_n)
655
+ else:
656
+ pos_clip_indices = [gt_st] * max_n #[gt_st, gt_st]
657
+
658
+ neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l))
659
+
660
+ try:
661
+ neg_clip_indices = random.sample(neg_pool, k=max_n)
662
+ except:
663
+ neg_clip_indices = pos_clip_indices
664
+
665
+ return pos_clip_indices, neg_clip_indices
666
+
667
+ def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1):
668
+ """Sum the scores from the three annotations, then take the two clips with the
669
+ maximum scores as positive, and two with the minimum scores as negative.
670
+ Args:
671
+ rel_clip_ids: list(int), list of relevant clip ids
672
+ scores: list([anno1_score, anno2_score, anno3_score]),
673
+ ctx_l: int
674
+ max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively.
675
+ add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids.
676
+ """
677
+ # indices inside rel_clip_ids
678
+ scores = np.array(scores) # (#rel_clips, 3)
679
+ agg_scores = np.sum(scores, 1) # (#rel_clips, )
680
+ sort_indices = np.argsort(agg_scores) # increasing
681
+
682
+ # indices in the whole video
683
+ # the min(_, ctx_l-1) here is incorrect, but should not cause
684
+ # much troubles since this should be rarely used.
685
+ hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]]
686
+ hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]]
687
+
688
+ if agg_scores[sort_indices[-1]] == agg_scores[sort_indices[0]]:
689
+ hard_neg_clip_indices = hard_pos_clip_indices
690
+
691
+ easy_pos_clip_indices = []
692
+ easy_neg_clip_indices = []
693
+
694
+ if self.add_easy_negative > 0:
695
+ easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids))
696
+ if len(easy_neg_pool) >= max_n:
697
+ easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n)
698
+ easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n)
699
+ else: # copy the hard ones
700
+ easy_pos_clip_indices = hard_pos_clip_indices
701
+ easy_neg_clip_indices = hard_neg_clip_indices
702
+
703
+ if self.easy_negative_only > 0:
704
+ return easy_pos_clip_indices, easy_neg_clip_indices
705
+
706
+ pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices
707
+ neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices
708
+ return pos_clip_indices, neg_clip_indices
709
+
710
+ def get_span_labels(self, windows, ctx_l):
711
+ """
712
+ windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive)
713
+ Note a maximum of `self.max_windows` windows are used.
714
+ returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length
715
+ """
716
+ if len(windows) > self.max_windows:
717
+ random.shuffle(windows)
718
+ windows = windows[:self.max_windows]
719
+ if self.span_loss_type == "l1":
720
+ windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx
721
+ windows = span_xx_to_cxw(windows) # normalized windows in cxw
722
+ elif self.span_loss_type == "ce":
723
+ windows = torch.Tensor([
724
+ [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1]
725
+ for w in windows]).long() # inclusive
726
+ else:
727
+ raise NotImplementedError
728
+ return windows
729
+
730
+ def _get_query_feat_by_qid(self, qid):
731
+ if self.use_cache > 0:
732
+ try:
733
+ q_feat = self.txt_cache[qid]
734
+ except:
735
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
736
+ return torch.from_numpy(q_feat)
737
+
738
+ q_feat_path = join(self.q_feat_dir, f"{qid}.npz")
739
+ try:
740
+ q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
741
+ except:
742
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
743
+ logger.info(f"Something wrong when loading the query feature {q_feat_path}.")
744
+
745
+ if self.q_feat_type == "last_hidden_state":
746
+ # q_feat = q_feat[:self.max_q_l]
747
+ q_feat = q_feat
748
+ if self.normalize_t:
749
+ q_feat = l2_normalize_np_array(q_feat)
750
+ if self.txt_drop_ratio > 0:
751
+ q_feat = self.random_drop_rows(q_feat)
752
+ return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
753
+
754
+ def random_drop_rows(self, embeddings):
755
+ """randomly mask num_drop rows in embeddings to be zero.
756
+ Args:
757
+ embeddings: np.ndarray (L, D)
758
+ """
759
+ num_drop_rows = round(len(embeddings) * self.txt_drop_ratio)
760
+ if num_drop_rows > 0:
761
+ row_indices = np.random.choice(
762
+ len(embeddings), size=num_drop_rows, replace=False)
763
+ embeddings[row_indices] = 0
764
+ return embeddings
765
+
766
+ def _get_video_feat_by_vid(self, vid):
767
+ v_feat_list = []
768
+ for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
769
+ if self.use_cache > 0:
770
+ _feat = self.vid_cache[feat_type][vid]
771
+ else:
772
+ _feat_path = join(_feat_dir, f"{vid}.npz")
773
+ _feat = np.load(_feat_path)["features"].astype(np.float32)
774
+ # _feat = np.load(_feat_path)["features"][:self.max_v_l].astype(np.float32)
775
+ if self.normalize_v:
776
+ _feat = l2_normalize_np_array(_feat)
777
+ v_feat_list.append(_feat)
778
+ # some features are slightly longer than the others
779
+ min_len = min([len(e) for e in v_feat_list])
780
+ v_feat_list = [e[:min_len] for e in v_feat_list]
781
+ v_feat = np.concatenate(v_feat_list, axis=1)
782
+ return torch.from_numpy(v_feat) # (Lv, D)
783
+
784
+ class DatasetHL(Dataset):
785
+ def __init__(self,
786
+ dset_name,
787
+ domain,
788
+ data_path,
789
+ v_feat_types,
790
+ v_feat_dirs,
791
+ t_feat_dir,
792
+ use_tef=False
793
+ ):
794
+ assert dset_name in ['tvsum', 'youtube']
795
+ self.dset_name = dset_name
796
+ dset_domain = {'tvsum': TVSUM_SPLITS,
797
+ 'youtube': YOUTUBE_SPLITS}
798
+ self.splits = dset_domain[dset_name]
799
+ assert domain in self.splits.keys()
800
+
801
+ self.domain = domain
802
+ assert len(data_path) == 1
803
+ self.data_path = data_path[0] if isinstance(data_path, list) else data_path
804
+ self.v_feat_types = v_feat_types.split('_')
805
+ self.v_feat_dirs = v_feat_dirs
806
+ self.q_feat_type = "last_hidden_state"
807
+ self.q_feat_dir = t_feat_dir
808
+
809
+ self.txt_drop_ratio = 0
810
+ self.normalize_t = True
811
+ self.normalize_v = True
812
+
813
+ self.label = nncore.load(self.data_path)
814
+ self.use_tef = use_tef
815
+
816
+ self.video_id = {
817
+ k: [s for s in self.splits[domain][k] if s in self.label]
818
+ for k in ('train', 'val')
819
+ }
820
+ self.set_state('train')
821
+
822
+ def __len__(self):
823
+ return len(self.video_id[self.state])
824
+
825
+ def __getitem__(self, idx):
826
+ vid = self.get_video_id(idx)
827
+ video = self._get_video_feat_by_vid(vid)
828
+ saliency = self.get_saliency(idx)
829
+
830
+ if self.dset_name == 'youtube':
831
+ saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())])
832
+ elif self.dset_name == 'tvsum':
833
+ saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency > 0)[0].tolist())])
834
+ # saliency_pos_labels = torch.Tensor([random.choice(torch.where(saliency != min(saliency))[0].tolist())])
835
+ else:
836
+ raise NotImplementedError
837
+
838
+ num_clips = min(c.size(0) for c in (video, saliency))
839
+
840
+ video = video[:num_clips]
841
+ saliency = saliency[:num_clips]
842
+
843
+ if self.use_tef:
844
+ ctx_l = video.shape[0]
845
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
846
+ tef_ed = tef_st + 1.0 / ctx_l
847
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
848
+ video = torch.cat([video, tef], dim=1) # (Lv, Dv+2)
849
+
850
+ data = dict(
851
+ video=DataContainer(video),
852
+ saliency=DataContainer(saliency, pad_value=-1),
853
+ saliency_pos_labels=saliency_pos_labels)
854
+
855
+ if self.q_feat_dir is not None:
856
+ query = self._get_query_feat_by_qid(vid)
857
+ data['query'] = DataContainer(query, pad_value=float('inf'))
858
+ return data
859
+
860
+ def set_state(self, state):
861
+ self.state = 'train' if state == 'train' else 'val'
862
+
863
+ def get_video_id(self, idx):
864
+ return self.video_id[self.state][idx]
865
+
866
+ def get_video(self, idx):
867
+ video_id = self.get_video_id(idx)
868
+ video = torch.from_numpy(self.video[video_id]).float()
869
+ optic = torch.from_numpy(self.optic[video_id]).float()
870
+ return torch.cat((video, optic), dim=1)
871
+
872
+ def _get_video_feat_by_vid(self, vid):
873
+ v_feat_list = []
874
+ for feat_type, _feat_dir in zip(self.v_feat_types, self.v_feat_dirs):
875
+ # if self.use_cache > 0:
876
+ # _feat = self.vid_cache[feat_type][vid]
877
+ # else:
878
+ if True:
879
+ _feat_path = join(_feat_dir, f"{vid}.npz")
880
+ _feat = np.load(_feat_path)["features"].astype(np.float32)
881
+ if self.normalize_v:
882
+ _feat = l2_normalize_np_array(_feat)
883
+ v_feat_list.append(_feat)
884
+ # some features are slightly longer than the others
885
+ min_len = min([len(e) for e in v_feat_list])
886
+ v_feat_list = [e[:min_len] for e in v_feat_list]
887
+ v_feat = np.concatenate(v_feat_list, axis=1)
888
+ return torch.from_numpy(v_feat) # (Lv, D)
889
+
890
+ def _get_query_feat_by_qid(self, qid):
891
+ # if self.use_cache > 0:
892
+ # try:
893
+ # q_feat = self.txt_cache[qid]
894
+ # except:
895
+ # q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
896
+ # return torch.from_numpy(q_feat)
897
+
898
+ q_feat_path = join(self.q_feat_dir, f"{qid}.npz")
899
+ try:
900
+ q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32)
901
+ except:
902
+ q_feat = np.zeros((10, self.q_feat_dim)).astype(np.float32)
903
+ logger.info(f"Something wrong when loading the query feature {q_feat_path}.")
904
+
905
+ if self.q_feat_type == "last_hidden_state":
906
+ # q_feat = q_feat[:self.max_q_l]
907
+ q_feat = q_feat
908
+ if self.normalize_t:
909
+ q_feat = l2_normalize_np_array(q_feat)
910
+ if self.txt_drop_ratio > 0:
911
+ q_feat = self.random_drop_rows(q_feat)
912
+ return torch.from_numpy(q_feat) # (D, ) or (Lq, D)
913
+
914
+ def get_saliency(self, idx):
915
+ if self.dset_name == 'tvsum':
916
+ video_id = self.get_video_id(idx)
917
+ saliency = torch.Tensor(self.label[video_id]['anno'])
918
+
919
+ # top-5 saliency scores as a threshold.
920
+ # saliency_tmp = saliency.mean(1)
921
+ # topk = int(saliency_tmp.shape[0] * 0.1)
922
+ # th = saliency_tmp[torch.sort(saliency_tmp)[1][-topk]] # v4
923
+ # saliency = saliency_tmp - th
924
+
925
+ # saliency_tmp = saliency.mean(1) # med
926
+ # th = saliency_tmp.median()
927
+ # saliency = saliency_tmp - th
928
+
929
+ saliency = (saliency - saliency.mean()).mean(dim=1)
930
+ # saliency = (saliency.sum(dim=1) - 20) / 80 # v2
931
+
932
+ elif self.dset_name == 'youtube':
933
+ video_id = self.get_video_id(idx)
934
+ saliency = [1 if s > 0 else 0 for s in self.label[video_id]['match']]
935
+ else:
936
+ raise NotImplementedError
937
+ return torch.Tensor(saliency)
938
+
939
+ def evaluate(self, blob, k=5, save_dir=None, **kwargs):
940
+ # blob = nncore.to_dict_of_list(blob)
941
+ collected = []
942
+
943
+ if save_dir is not None:
944
+ import json
945
+ with open(os.path.join(save_dir, self.dset_name, self.domain +'.jsonl'), 'w') as f:
946
+ for idx, score in enumerate(blob):
947
+ video_id = self.get_video_id(idx)
948
+ entry = {'vid':video_id, 'pred': score[0].tolist(), 'gt': self.get_saliency(idx).tolist(),
949
+ 'duration': int(self.label[video_id]['frames']) / int(self.label[video_id]['fps']),
950
+ 'domain': self.label[video_id]['domain'], 'fps': self.label[video_id]['fps']}
951
+ if self.dset_name == 'tvsum':
952
+ entry.update({'title':self.label[video_id]['title']})
953
+ if self.dset_name == 'youtube':
954
+ entry.update({'clip':self.label[video_id]['clip']})
955
+ f.write(json.dumps(entry) + '\n')
956
+
957
+ if self.dset_name == 'tvsum':
958
+ for i in range(20):
959
+ video_ap = []
960
+ for idx, score in enumerate(blob):
961
+ inds = torch.argsort(score[0], descending=True)
962
+ video_id = self.get_video_id(idx)
963
+ label = torch.Tensor(self.label[video_id]['anno'])[:, i]
964
+ label = torch.where(label > label.median(), 1.0, .0)
965
+ label = label[inds].tolist()[:k]
966
+
967
+ if (num_gt := sum(label)) == 0:
968
+ video_ap.append(0)
969
+ continue
970
+
971
+ hits = ap = rec = 0
972
+ prc = 1
973
+
974
+ for j, gt in enumerate(label):
975
+ hits += gt
976
+ _rec = hits / num_gt
977
+ _prc = hits / (j + 1)
978
+ ap += (_rec - rec) * (prc + _prc) / 2
979
+ rec, prc = _rec, _prc
980
+ video_ap.append(ap)
981
+ collected.append(sum(video_ap) / len(video_ap))
982
+
983
+ elif self.dset_name == 'youtube':
984
+ for idx, score in enumerate(blob):
985
+ inds = torch.argsort(score[0], descending=True)
986
+ label = self.get_saliency(idx)[inds].tolist()
987
+
988
+ if (num_gt := sum(label)) == 0:
989
+ collected.append(0)
990
+ continue
991
+
992
+ hits = ap = rec = 0
993
+ prc = 1
994
+
995
+ for i, gt in enumerate(label):
996
+ hits += gt
997
+ _rec = hits / num_gt
998
+ _prc = hits / (i + 1)
999
+ ap += (_rec - rec) * (prc + _prc) / 2
1000
+ rec, prc = _rec, _prc
1001
+ collected.append(ap)
1002
+ else:
1003
+ raise NotImplementedError
1004
+
1005
+ mean_ap = sum(collected) / len(collected)
1006
+ results = dict(mAP=round(mean_ap, 5))
1007
+ return results
1008
+
1009
+ class DatasetQFVS(Dataset):
1010
+ def __init__(self,config, use_tef=True):
1011
+ # pdb.set_trace()
1012
+ self.config=config
1013
+ self.dataset=[]
1014
+ self.use_tef=use_tef
1015
+
1016
+ self.embedding=load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")
1017
+
1018
+ for video_id in self.config["train_videos"]:
1019
+ for _ , _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
1020
+ for file in files:
1021
+ self.dataset.append(file[:file.find("_oracle.txt")]+"_"+str(video_id))
1022
+
1023
+ def __getitem__(self,index):
1024
+ video_id=self.dataset[index].split('_')[2]
1025
+ feat_type = self.config['vid_feature']
1026
+ # pdb.set_trace()
1027
+ feat_type = self.config['vid_feature']
1028
+ f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
1029
+ features=f['feature'][()]
1030
+ # dim=features.shape[-1]
1031
+ # features=features.reshape(-1, dim)
1032
+ # seg_len=f['seg_len'][()]
1033
+ dim = features.shape[-1]
1034
+ ctx_l = features.shape[0]
1035
+ seg_len = np.ones(ctx_l)
1036
+
1037
+ # mask = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
1038
+ # for j in range(len(seg_len)):
1039
+ # for k in range(seg_len[j]):
1040
+ # mask[j][k] = 1
1041
+
1042
+ # ctx_l = seg_len.sum()
1043
+ features = torch.from_numpy(features)
1044
+ # features = features[mask, :]
1045
+
1046
+ if self.use_tef:
1047
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
1048
+ tef_ed = tef_st + 1.0 / ctx_l
1049
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
1050
+ features = torch.cat([features, tef], dim=1) # (Lv, Dv+2)
1051
+
1052
+ transfer={"Cupglass":"Glass",
1053
+ "Musicalinstrument":"Instrument",
1054
+ "Petsanimal":"Animal"}
1055
+
1056
+ concept1,concept2=self.dataset[index].split('_')[0:2]
1057
+
1058
+ concept1_GT=torch.zeros(ctx_l)
1059
+ concept2_GT=torch.zeros(ctx_l)
1060
+ with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+video_id+"/P0"+video_id+".txt","r") as f:
1061
+ lines=f.readlines()
1062
+ for index,line in enumerate(lines):
1063
+ concepts=line.strip().split(',')
1064
+ if concept1 in concepts:
1065
+ concept1_GT[index]=1
1066
+ if concept2 in concepts:
1067
+ concept2_GT[index]=1
1068
+
1069
+ # shot_num=seg_len.sum()
1070
+ # mask_GT=torch.zeros(ctx_l)
1071
+ # for i in range(shot_num):
1072
+ # mask_GT[i]=1
1073
+ mask_GT=torch.ones(ctx_l)
1074
+
1075
+ oracle_summary = torch.zeros(ctx_l)
1076
+ GT_summary_shots = []
1077
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+str(concept1)+"_"+str(concept2)+"_"+"oracle.txt","r") as f:
1078
+ for line in f.readlines():
1079
+ GT_summary_shots.append(int(line.strip()))
1080
+ GT_summary_shots = [x - 1 for x in GT_summary_shots]
1081
+ for element in GT_summary_shots:
1082
+ oracle_summary[element] = 1
1083
+
1084
+ if concept1 in transfer:
1085
+ concept1=transfer[concept1]
1086
+ if concept2 in transfer:
1087
+ concept2=transfer[concept2]
1088
+ concept1=self.embedding[concept1]
1089
+ concept2=self.embedding[concept2]
1090
+
1091
+ try:
1092
+ saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())])
1093
+ except:
1094
+ saliency_pos_labels_1 = torch.Tensor(0)
1095
+
1096
+ try:
1097
+ saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 0)[0].tolist())])
1098
+ except:
1099
+ saliency_pos_labels_2 = torch.Tensor(0)
1100
+
1101
+ try:
1102
+ saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())])
1103
+ except:
1104
+ saliency_pos_labels_oracle = torch.Tensor(0)
1105
+
1106
+ return {
1107
+ 'features':features,
1108
+ 'seg_len':torch.from_numpy(seg_len),
1109
+ 'concept1_GT':concept1_GT,
1110
+ 'concept2_GT':concept2_GT,
1111
+ 'mask_GT':mask_GT,
1112
+ 'oracle_summary':oracle_summary,
1113
+ 'tokens_pad1':torch.from_numpy(concept1),
1114
+ 'tokens_pad2':torch.from_numpy(concept2),
1115
+ 'saliency_pos_labels_1': saliency_pos_labels_1,
1116
+ 'saliency_pos_labels_2': saliency_pos_labels_2,
1117
+ 'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
1118
+ }
1119
+
1120
+ def __len__(self):
1121
+ return len(self.dataset)
1122
+
1123
+ def start_end_collate_mr(batch):
1124
+ batch_meta = [e["meta"] for e in batch] # seems no need to collate ?
1125
+
1126
+ model_inputs_keys = batch[0]["model_inputs"].keys()
1127
+ batched_data = dict()
1128
+ for k in model_inputs_keys:
1129
+ if k == "span_labels":
1130
+ batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch]
1131
+ continue
1132
+ if k in ["saliency_pos_labels", "saliency_neg_labels"]:
1133
+ batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch])
1134
+ continue
1135
+
1136
+ batched_data[k] = pad_sequences_1d(
1137
+ [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)
1138
+ return batch_meta, batched_data
1139
+
1140
+ def start_end_collate_hl(batch):
1141
+ model_inputs_keys = batch[0].keys()
1142
+
1143
+ batched_data = dict()
1144
+ for k in model_inputs_keys:
1145
+ batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
1146
+ return batched_data
1147
+
1148
+ def start_end_collate_qfvs(batch):
1149
+ model_inputs_keys = batch[0].keys()
1150
+
1151
+ batched_data = dict()
1152
+ for k in model_inputs_keys:
1153
+ batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
1154
+
1155
+ return batched_data
1156
+
1157
+ def prepare_batch_inputs_mr(batched_model_inputs, device, non_blocking=False):
1158
+ model_inputs = dict(
1159
+ src_txt=batched_model_inputs["query_feat"][0].to(device, non_blocking=non_blocking),
1160
+ src_txt_mask=batched_model_inputs["query_feat"][1].to(device, non_blocking=non_blocking),
1161
+ src_vid=batched_model_inputs["video_feat"][0].to(device, non_blocking=non_blocking),
1162
+ src_vid_mask=batched_model_inputs["video_feat"][1].to(device, non_blocking=non_blocking),
1163
+ )
1164
+ targets = {}
1165
+ targets['timestamp'] = batched_model_inputs["timestamp"][0].to(device, non_blocking=non_blocking)
1166
+ targets['timestamp_mask'] = batched_model_inputs["timestamp"][1].to(device, non_blocking=non_blocking)
1167
+ targets['timestamp_window'] = batched_model_inputs["timestamp_window"][0].to(device, non_blocking=non_blocking)
1168
+ targets['span_labels_nn'] = batched_model_inputs["span_labels_nn"][0].to(device, non_blocking=non_blocking)
1169
+
1170
+ if 'saliency_scores' in batched_model_inputs.keys():
1171
+ targets['saliency_scores'] = batched_model_inputs["saliency_scores"][0].to(device, non_blocking=non_blocking)
1172
+
1173
+ if "span_labels" in batched_model_inputs:
1174
+ targets["span_labels"] = [
1175
+ dict(spans=e["spans"].to(device, non_blocking=non_blocking))
1176
+ for e in batched_model_inputs["span_labels"]
1177
+ ]
1178
+ if "saliency_pos_labels" in batched_model_inputs:
1179
+ for name in ["saliency_pos_labels", "saliency_neg_labels"]:
1180
+ targets[name] = batched_model_inputs[name].to(device, non_blocking=non_blocking)
1181
+
1182
+ if "weight_ablation" in batched_model_inputs:
1183
+ targets["weight_ablation"] = batched_model_inputs["weight_ablation"][0].to(device, non_blocking=non_blocking)
1184
+
1185
+ targets = None if len(targets) == 0 else targets
1186
+ return model_inputs, targets
1187
+
1188
+ def prepare_batch_inputs_hl(batched_model_inputs, device='cuda', non_blocking=False):
1189
+ src_vid = batched_model_inputs['video'][0].to(device, non_blocking=non_blocking)
1190
+ src_vid_mask = batched_model_inputs['video'][1].bool().to(device, non_blocking=non_blocking)
1191
+ src_txt = batched_model_inputs['query'][0].to(device, non_blocking=non_blocking) \
1192
+ if 'query' in batched_model_inputs.keys() else None
1193
+ src_txt_mask = batched_model_inputs['query'][1].bool().to(device, non_blocking=non_blocking) \
1194
+ if 'query' in batched_model_inputs.keys() else None
1195
+
1196
+ model_inputs = dict(
1197
+ src_vid=src_vid, src_vid_mask=src_vid_mask,
1198
+ src_txt=src_txt, src_txt_mask=src_txt_mask)
1199
+
1200
+ # if 'audio' in batched_model_inputs.keys():
1201
+ # src_aud = batched_model_inputs['audio'][0].bool().to(device, non_blocking=non_blocking)
1202
+ # src_aud_mask = batched_model_inputs['audio'][1].bool().to(device, non_blocking=non_blocking)
1203
+ # model_inputs['src_aud']=src_aud; model_inputs['src_aud_mask']=src_aud_mask;
1204
+
1205
+ targets = {}
1206
+ saliency = batched_model_inputs['saliency'][0].to(device, non_blocking=non_blocking)
1207
+ saliency_pos_labels = batched_model_inputs['saliency_pos_labels'][0].to(device, non_blocking=non_blocking)
1208
+
1209
+ targets['saliency_scores'] = saliency
1210
+ targets['saliency_pos_labels'] = saliency_pos_labels.long()
1211
+ targets['timestamp_mask'] = batched_model_inputs["video"][1].to(device, non_blocking=non_blocking)
1212
+ targets['timestamp_window'] = 1 * (saliency > 0)
1213
+
1214
+ return model_inputs, targets
1215
+
1216
+ def prepare_batch_inputs_qfvs(data, config, eval=False):
1217
+ if not eval:
1218
+ features, mask, seg_len, \
1219
+ concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \
1220
+ src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2,\
1221
+ saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \
1222
+ data['features'][0], data['features'][1], data['seg_len'][0],\
1223
+ data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0],\
1224
+ data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \
1225
+ data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0],
1226
+ else:
1227
+ features, mask, seg_len, \
1228
+ src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \
1229
+ data['features'][0], data['features'][1], data['seg_len'][0],\
1230
+ data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1]
1231
+
1232
+ # preprocess for vid input.
1233
+ seq = features.to('cuda')
1234
+ mask = mask.to('cuda')
1235
+
1236
+ # for txt input.
1237
+ src_txt_1 = src_txt_1.to(torch.float32).to('cuda')
1238
+ src_txt_2 = src_txt_2.to(torch.float32).to('cuda')
1239
+ src_txt_mask_1 = src_txt_mask_1.to('cuda')
1240
+ src_txt_mask_2 = src_txt_mask_2.to('cuda')
1241
+
1242
+ src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda')
1243
+ src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda')
1244
+
1245
+ model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
1246
+ model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
1247
+ model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)
1248
+
1249
+ if not eval:
1250
+ targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda'))
1251
+ targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda'))
1252
+ targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda'))
1253
+
1254
+ targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda')
1255
+ targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda')
1256
+ targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda')
1257
+
1258
+ return model_inputs_1, model_inputs_2, model_inputs_oracle, \
1259
+ targets_1, targets_2, targets_oracle, mask_GT
1260
+ else:
1261
+ return model_inputs_1, model_inputs_2, model_inputs_oracle, mask
main/dataset_qfvs.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import h5py
4
+ import nncore
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import random
10
+ import logging
11
+ from os.path import join, exists
12
+ from nncore.dataset import DATASETS
13
+ from nncore.parallel import DataContainer
14
+ from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
15
+ from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
16
+ from utils.tensor_utils import pad_sequences_1d
17
+ from utils.span_utils import span_xx_to_cxw
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ class DatasetQFVS(Dataset):
22
+ def __init__(self,config, use_tef=True):
23
+ # pdb.set_trace()
24
+ self.config=config
25
+ self.dataset=[]
26
+ self.use_tef=use_tef
27
+
28
+ self.embedding=load_pickle(f"./data/qfvs/txt_clip/{self.config['txt_feature']}.pkl")
29
+
30
+ self.transfer={"Cupglass":"Glass",
31
+ "Musicalinstrument":"Instrument",
32
+ "Petsanimal":"Animal"}
33
+
34
+ self.f_dict = {}
35
+ feat_type = self.config['vid_feature']
36
+
37
+ for video_id in self.config["train_videos"]:
38
+ self.f_dict[str(video_id)] = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
39
+ for _ , _, files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
40
+ for file in files:
41
+ self.dataset.append(['Oracle', file[:file.find("_oracle.txt")]+"_"+str(video_id)])
42
+
43
+ if self.config['qfvs_dense_shot'] > 0:
44
+ dense_concept = {}
45
+ feat_type = self.config['vid_feature']
46
+ feat=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
47
+ features=feat['features'][()]
48
+ seg_len=feat['seg_len'][()]
49
+ with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+str(video_id)+"/P0"+str(video_id)+".txt","r") as f:
50
+ lines=f.readlines()
51
+ for index,line in enumerate(lines):
52
+ concepts=line.strip().split(',')
53
+ for concept in concepts:
54
+ if concept in self.transfer:
55
+ concept= self.transfer[concept]
56
+ if concept not in dense_concept:
57
+ # dense_concept[concept] = torch.zeros(seg_len.sum())
58
+ dense_concept[concept] = torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"])
59
+ else:
60
+ dense_concept[concept][index] = 1
61
+
62
+ for key, value in dense_concept.items():
63
+ if value.sum().item() > 0:
64
+ self.dataset.append([video_id, key, value])
65
+
66
+ def __getitem__(self, index):
67
+ if self.dataset[index][0] == 'Oracle':
68
+ return self.get_oracle(index)
69
+ else:
70
+ return self.get_dense(index)
71
+
72
+ def get_dense(self,index):
73
+ video_id=str(self.dataset[index][0])
74
+ f = self.f_dict[video_id]
75
+ # feat_type = self.config['vid_feature']
76
+ # f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
77
+ features=f['features'][()]
78
+ seg_len=f['seg_len'][()]
79
+
80
+ dim = features.shape[-1]
81
+
82
+ mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
83
+ for j in range(len(seg_len)):
84
+ for k in range(seg_len[j]):
85
+ mask_GT[j][k] = 1
86
+
87
+ features = torch.from_numpy(features)
88
+
89
+ concept1 = concept2 = self.dataset[index][1]
90
+ concept1_GT = concept2_GT = oracle_summary = self.dataset[index][2]
91
+
92
+ if concept1 in self.transfer:
93
+ concept1=self.transfer[concept1]
94
+ if concept2 in self.transfer:
95
+ concept2=self.transfer[concept2]
96
+ concept1=self.embedding[concept1]
97
+ concept2=self.embedding[concept2]
98
+
99
+ concept1 = l2_normalize_np_array(concept1)
100
+ concept2 = l2_normalize_np_array(concept2)
101
+
102
+ try:
103
+ saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())])
104
+ except:
105
+ saliency_pos_labels_1 = torch.Tensor(0)
106
+
107
+ try:
108
+ saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 0)[0].tolist())])
109
+ except:
110
+ saliency_pos_labels_2 = torch.Tensor(0)
111
+
112
+ try:
113
+ saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())])
114
+ except:
115
+ saliency_pos_labels_oracle = torch.Tensor(0)
116
+
117
+ return {
118
+ 'features':features,
119
+ 'seg_len':torch.from_numpy(seg_len),
120
+ 'concept1_GT':concept1_GT,
121
+ 'concept2_GT':concept2_GT,
122
+ 'mask_GT':mask_GT,
123
+ 'oracle_summary':oracle_summary,
124
+ 'tokens_pad1':torch.from_numpy(concept1),
125
+ 'tokens_pad2':torch.from_numpy(concept2),
126
+ 'saliency_pos_labels_1': saliency_pos_labels_1,
127
+ 'saliency_pos_labels_2': saliency_pos_labels_2,
128
+ 'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
129
+ }
130
+
131
+ def get_oracle(self,index):
132
+ video_id=self.dataset[index][1].split('_')[2]
133
+ f = self.f_dict[video_id]
134
+ # video_id=self.dataset[index][1].split('_')[2]
135
+ # feat_type = self.config['vid_feature']
136
+ # f=h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5','r')
137
+ features=f['features'][()]
138
+ seg_len=f['seg_len'][()]
139
+
140
+ dim = features.shape[-1]
141
+
142
+ mask_GT = torch.zeros(self.config["max_segment_num"], self.config["max_frame_num"], dtype=torch.bool)
143
+ for j in range(len(seg_len)):
144
+ for k in range(seg_len[j]):
145
+ mask_GT[j][k] = 1
146
+
147
+ features = torch.from_numpy(features)
148
+
149
+ concept1,concept2=self.dataset[index][1].split('_')[0:2]
150
+
151
+ concept1_GT=torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"])
152
+ concept2_GT=torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"])
153
+ # concept1_GT=torch.zeros(seg_len.sum())
154
+ # concept2_GT= torch.zeros(seg_len.sum())
155
+ with open("./data/qfvs/metadata/origin_data/Dense_per_shot_tags/P0"+video_id+"/P0"+video_id+".txt","r") as f:
156
+ lines=f.readlines()
157
+ for index,line in enumerate(lines):
158
+ concepts=line.strip().split(',')
159
+ if concept1 in concepts:
160
+ concept1_GT[index]=1
161
+ if concept2 in concepts:
162
+ concept2_GT[index]=1
163
+
164
+ # oracle_summary =torch.zeros(seg_len.sum())
165
+ oracle_summary = torch.zeros(self.config["max_segment_num"]*self.config["max_frame_num"])
166
+ GT_summary_shots = []
167
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+str(concept1)+"_"+str(concept2)+"_"+"oracle.txt","r") as f:
168
+ for line in f.readlines():
169
+ GT_summary_shots.append(int(line.strip()))
170
+ GT_summary_shots = [x - 1 for x in GT_summary_shots]
171
+ for element in GT_summary_shots:
172
+ oracle_summary[element] = 1
173
+
174
+ if concept1 in self.transfer:
175
+ concept1=self.transfer[concept1]
176
+ if concept2 in self.transfer:
177
+ concept2=self.transfer[concept2]
178
+ concept1=self.embedding[concept1]
179
+ concept2=self.embedding[concept2]
180
+
181
+ concept1 = l2_normalize_np_array(concept1)
182
+ concept2 = l2_normalize_np_array(concept2)
183
+
184
+ try:
185
+ saliency_pos_labels_1 = torch.Tensor([random.choice(torch.where(concept1_GT> 0)[0].tolist())])
186
+ except:
187
+ saliency_pos_labels_1 = torch.Tensor(0)
188
+
189
+ try:
190
+ saliency_pos_labels_2 = torch.Tensor([random.choice(torch.where(concept2_GT> 0)[0].tolist())])
191
+ except:
192
+ saliency_pos_labels_2 = torch.Tensor(0)
193
+
194
+ try:
195
+ saliency_pos_labels_oracle = torch.Tensor([random.choice(torch.where(oracle_summary> 0)[0].tolist())])
196
+ except:
197
+ saliency_pos_labels_oracle = torch.Tensor(0)
198
+
199
+ return {
200
+ 'features':features,
201
+ 'seg_len':torch.from_numpy(seg_len),
202
+ 'concept1_GT':concept1_GT,
203
+ 'concept2_GT':concept2_GT,
204
+ 'mask_GT':mask_GT,
205
+ 'oracle_summary':oracle_summary,
206
+ 'tokens_pad1':torch.from_numpy(concept1),
207
+ 'tokens_pad2':torch.from_numpy(concept2),
208
+ 'saliency_pos_labels_1': saliency_pos_labels_1,
209
+ 'saliency_pos_labels_2': saliency_pos_labels_2,
210
+ 'saliency_pos_labels_oracle': saliency_pos_labels_oracle,
211
+ }
212
+
213
+ def __len__(self):
214
+ return len(self.dataset)
215
+
216
+ def start_end_collate_qfvs(batch):
217
+ model_inputs_keys = batch[0].keys()
218
+
219
+ batched_data = dict()
220
+ for k in model_inputs_keys:
221
+ batched_data[k] = pad_sequences_1d([e[k].data for e in batch], dtype=torch.float32, fixed_length=None)
222
+
223
+ return batched_data
224
+
225
+ def prepare_batch_inputs_qfvs(data, config, eval=False):
226
+ if not eval:
227
+ features, mask, seg_len, \
228
+ concept1_GT, concept2_GT, mask_GT, oracle_summary_GT, \
229
+ src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2,\
230
+ saliency_pos_labels_1, saliency_pos_labels_2, saliency_pos_labels_oracle = \
231
+ data['features'][0], data['mask_GT'][0], data['seg_len'][0],\
232
+ data['concept1_GT'][0], data['concept2_GT'][0], data['mask_GT'][0], data['oracle_summary'][0],\
233
+ data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1], \
234
+ data['saliency_pos_labels_1'][0], data['saliency_pos_labels_2'][0], data['saliency_pos_labels_oracle'][0],
235
+ else:
236
+ features, mask, seg_len, \
237
+ src_txt_1, src_txt_2, src_txt_mask_1, src_txt_mask_2 = \
238
+ data['features'][0], data['mask_GT'][0], data['seg_len'][0],\
239
+ data['tokens_pad1'][0], data['tokens_pad2'][0], data['tokens_pad1'][1], data['tokens_pad2'][1]
240
+
241
+ # preprocess for vid input.
242
+ mask_GT = mask.to('cuda').reshape(1, -1).bool()
243
+ seq = features.to('cuda').squeeze(0)
244
+ mask = mask.to('cuda').squeeze(0)
245
+ num_seg = seq.shape[0]
246
+
247
+ ctx_l = seq.shape[1]
248
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
249
+ tef_ed = tef_st + 1.0 / ctx_l
250
+ tef = torch.stack([tef_st, tef_ed], dim=1).to('cuda') # (Lv, 2)
251
+
252
+ tef = tef.squeeze(0).repeat(seq.shape[0], 1, 1)
253
+ seq = torch.cat([seq, tef], dim=-1)
254
+
255
+ # for txt input.
256
+ src_txt_1 = src_txt_1.to(torch.float32).to('cuda').repeat(num_seg, 1, 1)
257
+ src_txt_2 = src_txt_2.to(torch.float32).to('cuda').repeat(num_seg, 1, 1)
258
+ src_txt_mask_1 = src_txt_mask_1.to('cuda').repeat(num_seg, 1)
259
+ src_txt_mask_2 = src_txt_mask_2.to('cuda').repeat(num_seg, 1)
260
+
261
+ src_txt_oracle = torch.cat((src_txt_1, src_txt_2), dim=1).to('cuda')
262
+ src_txt_mask_oracle = torch.cat((src_txt_mask_1, src_txt_mask_2), dim=1).to('cuda')
263
+
264
+ model_inputs_1 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_1, src_txt_mask=src_txt_mask_1)
265
+ model_inputs_2 = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_2, src_txt_mask=src_txt_mask_2)
266
+ model_inputs_oracle = dict(src_vid=seq, src_vid_mask=mask, src_txt=src_txt_oracle, src_txt_mask=src_txt_mask_oracle)
267
+
268
+ # concept1_GT = concept1_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
269
+ # concept2_GT = concept2_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
270
+ # oracle_summary_GT = oracle_summary_GT.squeeze().reshape(config['max_segment_num'], config['max_frame_num'])
271
+
272
+ if not eval:
273
+ targets_1 = dict(saliency_scores=concept1_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_1.to('cuda'))
274
+ targets_2 = dict(saliency_scores=concept2_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_2.to('cuda'))
275
+ targets_oracle = dict(saliency_scores=oracle_summary_GT.to('cuda'), saliency_pos_labels=saliency_pos_labels_oracle.to('cuda'))
276
+
277
+ targets_1['timestamp_mask'] = mask; targets_1['timestamp_window'] = concept1_GT.to('cuda')
278
+ targets_2['timestamp_mask'] = mask; targets_2['timestamp_window'] = concept2_GT.to('cuda')
279
+ targets_oracle['timestamp_mask'] = mask; targets_oracle['timestamp_window'] = oracle_summary_GT.to('cuda')
280
+
281
+ return model_inputs_1, model_inputs_2, model_inputs_oracle, \
282
+ targets_1, targets_2, targets_oracle, mask_GT
283
+ else:
284
+ return model_inputs_1, model_inputs_2, model_inputs_oracle, mask_GT
main/inference_demo.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import pprint
3
+ from tqdm import tqdm, trange
4
+ import numpy as np
5
+ import os
6
+ from collections import OrderedDict, defaultdict
7
+ from utils.basic_utils import AverageMeter
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.backends.cudnn as cudnn
12
+ from torch.utils.data import DataLoader
13
+
14
+ from main.config import TestOptions, setup_model
15
+ from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
16
+ from eval.eval import eval_submission
17
+ from eval.postprocessing import PostProcessorDETR
18
+ from utils.basic_utils import save_jsonl, save_json
19
+ from utils.temporal_nms import temporal_nms
20
+ from utils.span_utils import span_cxw_to_xx
21
+ from utils.basic_utils import load_jsonl, load_pickle, l2_normalize_np_array
22
+
23
+ import logging
24
+ import importlib
25
+
26
+ logger = logging.getLogger(__name__)
27
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
28
+ datefmt="%Y-%m-%d %H:%M:%S",
29
+ level=logging.INFO)
30
+
31
+ def load_model():
32
+ logger.info("Setup config, data and model...")
33
+ opt = TestOptions().parse()
34
+ # pdb.set_trace()
35
+ cudnn.benchmark = True
36
+ cudnn.deterministic = False
37
+
38
+ model, criterion, _, _ = setup_model(opt)
39
+ return model
40
+
41
+ def load_data(save_dir):
42
+ vid = np.load(os.path.join(save_dir, 'vid.npz'))['features'].astype(np.float32)
43
+ txt = np.load(os.path.join(save_dir, 'txt.npz'))['features'].astype(np.float32)
44
+
45
+ vid = torch.from_numpy(l2_normalize_np_array(vid))
46
+ txt = torch.from_numpy(l2_normalize_np_array(txt))
47
+ clip_len = 2
48
+ ctx_l = vid.shape[0]
49
+
50
+ timestamp = ( (torch.arange(0, ctx_l) + clip_len / 2) / ctx_l).unsqueeze(1).repeat(1, 2)
51
+
52
+ if True:
53
+ tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
54
+ tef_ed = tef_st + 1.0 / ctx_l
55
+ tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2)
56
+ vid = torch.cat([vid, tef], dim=1) # (Lv, Dv+2)
57
+
58
+ src_vid = vid.unsqueeze(0).cuda()
59
+ src_txt = txt.unsqueeze(0).cuda()
60
+ src_vid_mask = torch.ones(src_vid.shape[0], src_vid.shape[1]).cuda()
61
+ src_txt_mask = torch.ones(src_txt.shape[0], src_txt.shape[1]).cuda()
62
+
63
+ return src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l
64
+
65
+ if __name__ == '__main__':
66
+ clip_len = 2
67
+ save_dir = '/data/home/qinghonglin/univtg/demo/tmp'
68
+
69
+ model = load_model()
70
+ src_vid, src_txt, src_vid_mask, src_txt_mask, timestamp, ctx_l = load_data(save_dir)
71
+ with torch.no_grad():
72
+ output = model(src_vid=src_vid, src_txt=src_txt, src_vid_mask=src_vid_mask, src_txt_mask=src_txt_mask)
73
+
74
+ pred_logits = output['pred_logits'][0].cpu()
75
+ pred_spans = output['pred_spans'][0].cpu()
76
+ pred_saliency = output['saliency_scores'].cpu()
77
+
78
+ pdb.set_trace()
79
+ top1 = (pred_spans + timestamp)[torch.argmax(pred_logits)] * ctx_l * clip_len
80
+ print(top1)
81
+ print(pred_saliency.argmax()*clip_len)
main/inference_hl.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import time
4
+ import json
5
+ import pprint
6
+ import random
7
+ import importlib
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.backends.cudnn as cudnn
15
+ from torch.utils.data import DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ import sys
19
+ sys.path.append('/Users/kevin/univtg')
20
+ from main.config import BaseOptions, setup_model
21
+ from main.dataset import DatasetHL, prepare_batch_inputs_hl, start_end_collate_hl
22
+ from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl
23
+ from utils.model_utils import count_parameters
24
+
25
+ import logging
26
+ logger = logging.getLogger(__name__)
27
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
28
+ datefmt="%Y-%m-%d %H:%M:%S",
29
+ level=logging.INFO)
30
+
31
+ def eval_epoch(model, train_val_dataset, opt): #, nms_thresh, device):
32
+ model.eval()
33
+
34
+ scores = []
35
+ train_val_dataset.set_state('val')
36
+ val_loader = DataLoader(
37
+ train_val_dataset,
38
+ collate_fn=start_end_collate_hl,
39
+ batch_size=opt.eval_bsz,
40
+ num_workers=opt.num_workers,
41
+ shuffle=False,
42
+ pin_memory=opt.pin_memory
43
+ )
44
+
45
+ with torch.no_grad():
46
+ for data in val_loader:
47
+ model_inputs, targets = prepare_batch_inputs_hl(data)
48
+ outputs = model(**model_inputs)
49
+ # pred_cls = outputs['pred_logits'].squeeze(-1)
50
+ # pred_cls = outputs['saliency_scores']
51
+ # pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
52
+
53
+ # pdb.set_trace()
54
+ if opt.f_loss_coef == 0:
55
+ pred_cls = outputs['saliency_scores']
56
+ elif opt.s_loss_intra_coef == 0:
57
+ pred_cls = outputs['pred_logits'].squeeze(-1)
58
+ else:
59
+ if opt.eval_mode == 'add':
60
+ pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
61
+ else:
62
+ pred_cls = outputs['pred_logits'].squeeze(-1)
63
+
64
+ pred_cls = pred_cls.detach().cpu()
65
+ scores.append(pred_cls)
66
+ map = round(train_val_dataset.evaluate(scores, save_dir='./plot')['mAP'] * 100, 4)
67
+ return map
68
+
69
+ def train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer):
70
+ logger.info(f"[Epoch {epoch_i+1}]")
71
+ model.train()
72
+ criterion.train()
73
+
74
+ train_val_dataset.set_state('train')
75
+ train_loader = DataLoader(
76
+ train_val_dataset,
77
+ collate_fn=start_end_collate_hl,
78
+ batch_size=opt.bsz,
79
+ num_workers=opt.num_workers,
80
+ shuffle=True,
81
+ pin_memory=opt.pin_memory
82
+ )
83
+
84
+ # init meters
85
+ time_meters = defaultdict(AverageMeter)
86
+ loss_meters = defaultdict(AverageMeter)
87
+
88
+ num_training_examples = len(train_loader)
89
+ timer_dataloading = time.time()
90
+ for batch_idx, batch in enumerate(train_loader):
91
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
92
+ timer_start = time.time()
93
+ model_inputs, targets = prepare_batch_inputs_hl(batch)
94
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
95
+
96
+ timer_start = time.time()
97
+ outputs = model(**model_inputs)
98
+ loss_dict = criterion(outputs, targets)
99
+ weight_dict = criterion.weight_dict
100
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
101
+ time_meters["model_forward_time"].update(time.time() - timer_start)
102
+
103
+ timer_start = time.time()
104
+ optimizer.zero_grad()
105
+ losses.backward()
106
+ if opt.grad_clip > 0:
107
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
108
+ optimizer.step()
109
+ time_meters["model_backward_time"].update(time.time() - timer_start)
110
+
111
+ loss_dict["loss_overall"] = float(losses)
112
+ for k, v in loss_dict.items():
113
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
114
+
115
+ timer_dataloading = time.time()
116
+ if opt.debug and batch_idx == 3:
117
+ break
118
+
119
+ # print/add logs
120
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
121
+ for k, v in loss_meters.items():
122
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
123
+
124
+ to_write = opt.train_log_txt_formatter.format(
125
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
126
+ epoch=epoch_i+1,
127
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
128
+ with open(opt.train_log_filepath, "a") as f:
129
+ f.write(to_write)
130
+
131
+ logger.info("Epoch time stats:")
132
+ for name, meter in time_meters.items():
133
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
134
+ logger.info(f"{name} ==> {d}")
135
+
136
+ # train in single domain.
137
+ def train(model, criterion, optimizer, lr_scheduler, train_val_dataset, opt):
138
+ # if opt.device.type == "cuda":
139
+ # logger.info("CUDA enabled.")
140
+ # model.to(opt.device)
141
+
142
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
143
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
144
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
145
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
146
+
147
+ prev_best_score = 0.
148
+ if opt.start_epoch is None:
149
+ start_epoch = -1 if opt.eval_init else 0
150
+ else:
151
+ start_epoch = opt.start_epoch
152
+
153
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
154
+ if epoch_i > -1:
155
+ train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer)
156
+ lr_scheduler.step()
157
+ eval_epoch_interval = opt.eval_epoch
158
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
159
+ with torch.no_grad():
160
+ scores = eval_epoch(model, train_val_dataset, opt)
161
+ tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-{train_val_dataset.domain}-mAP", float(scores), epoch_i+1)
162
+ if prev_best_score < scores:
163
+ prev_best_score = scores
164
+ checkpoint = {
165
+ "model": model.state_dict(),
166
+ "optimizer": optimizer.state_dict(),
167
+ "epoch": epoch_i,
168
+ "opt": opt
169
+ }
170
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_{train_val_dataset.domain}_best.ckpt"))
171
+ tb_writer.close()
172
+ return prev_best_score
173
+
174
+ def start_training():
175
+ logger.info("Setup config, data and model...")
176
+ opt = BaseOptions().parse()
177
+ set_seed(opt.seed)
178
+
179
+ from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
180
+ if opt.dset_name == "tvsum":
181
+ domain_splits = TVSUM_SPLITS.keys()
182
+ if opt.dset_name == "youtube":
183
+ domain_splits = YOUTUBE_SPLITS.keys()
184
+
185
+ scores = {}
186
+ if opt.lr_warmup > 0:
187
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
188
+ total_steps = opt.n_epoch
189
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
190
+ opt.lr_warmup = [warmup_steps, total_steps]
191
+
192
+ domain_splits = domain_splits if not opt.domain_name else [opt.domain_name]
193
+
194
+ for domain in domain_splits:
195
+ dataset_config = dict(
196
+ dset_name=opt.dset_name,
197
+ domain=domain,
198
+ data_path=opt.train_path,
199
+ v_feat_types=opt.v_feat_types,
200
+ v_feat_dirs=opt.v_feat_dirs,
201
+ t_feat_dir=opt.t_feat_dir,
202
+ use_tef=True
203
+ )
204
+ dataloader = DatasetHL(**dataset_config)
205
+
206
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
207
+ count_parameters(model)
208
+ logger.info(f"Start Training {domain}")
209
+ best_score = train(model, criterion, optimizer, lr_scheduler, dataloader, opt)
210
+ scores[domain] = best_score
211
+ scores['AVG'] = sum(scores.values()) / len(scores)
212
+
213
+ # save the final results.
214
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
215
+ save_json(scores, save_metrics_path, save_pretty=True, sort_keys=False)
216
+
217
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
218
+ tb_writer.add_text(f"HL-{opt.dset_name}", dict_to_markdown(scores, max_str_len=None))
219
+ tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-avg-mAP-key", float(scores['AVG']), 1)
220
+ tb_writer.close()
221
+ # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
222
+
223
+ print(opt.dset_name)
224
+ print(scores)
225
+ return
226
+
227
+ if __name__ == '__main__':
228
+ start_training()
229
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
main/inference_mr.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import pprint
3
+ from tqdm import tqdm, trange
4
+ import numpy as np
5
+ import os
6
+ from collections import OrderedDict, defaultdict
7
+ from utils.basic_utils import AverageMeter
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import torch.backends.cudnn as cudnn
12
+ from torch.utils.data import DataLoader
13
+
14
+ from main.config import TestOptions, setup_model
15
+ from main.dataset import DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
16
+ from eval.eval import eval_submission
17
+ from eval.postprocessing import PostProcessorDETR
18
+ from utils.basic_utils import save_jsonl, save_json
19
+ from utils.temporal_nms import temporal_nms
20
+ from utils.span_utils import span_cxw_to_xx
21
+
22
+ import logging
23
+ import importlib
24
+
25
+ logger = logging.getLogger(__name__)
26
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
27
+ datefmt="%Y-%m-%d %H:%M:%S",
28
+ level=logging.INFO)
29
+
30
+
31
+ def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms):
32
+ mr_res_after_nms = []
33
+ for e in mr_res:
34
+ e["pred_relevant_windows"] = temporal_nms(
35
+ e["pred_relevant_windows"][:max_before_nms],
36
+ nms_thd=nms_thd,
37
+ max_after_nms=max_after_nms
38
+ )
39
+ mr_res_after_nms.append(e)
40
+ return mr_res_after_nms
41
+
42
+
43
+ def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename):
44
+ # IOU_THDS = (0.5, 0.7)
45
+ logger.info("Saving/Evaluating before nms results")
46
+ submission_path = os.path.join(opt.results_dir, save_submission_filename)
47
+ save_jsonl(submission, submission_path)
48
+
49
+ if opt.eval_split_name in ["val", "test"]: # since test_public has no GT
50
+ metrics = eval_submission(
51
+ submission, gt_data,
52
+ verbose=opt.debug, match_number=not opt.debug,
53
+ )
54
+ save_metrics_path = submission_path.replace(".jsonl", "_metrics.json")
55
+ save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
56
+ latest_file_paths = [submission_path, save_metrics_path]
57
+ else:
58
+ metrics = None
59
+ latest_file_paths = [submission_path, ]
60
+
61
+ if opt.nms_thd != -1:
62
+ logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd))
63
+ submission_after_nms = post_processing_mr_nms(
64
+ submission, nms_thd=opt.nms_thd,
65
+ max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms
66
+ )
67
+
68
+ logger.info("Saving/Evaluating nms results")
69
+ submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd))
70
+ save_jsonl(submission_after_nms, submission_nms_path)
71
+ if opt.eval_split_name == "val":
72
+ metrics_nms = eval_submission(
73
+ submission_after_nms, gt_data,
74
+ verbose=opt.debug, match_number=not opt.debug
75
+ )
76
+ save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json")
77
+ save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
78
+ latest_file_paths += [submission_nms_path, save_metrics_nms_path]
79
+ else:
80
+ metrics_nms = None
81
+ latest_file_paths = [submission_nms_path, ]
82
+ else:
83
+ metrics_nms = None
84
+ return metrics, metrics_nms, latest_file_paths
85
+
86
+
87
+ @torch.no_grad()
88
+ def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None):
89
+ model.eval()
90
+ if criterion:
91
+ assert eval_loader.dataset.load_labels
92
+ criterion.eval()
93
+
94
+ loss_meters = defaultdict(AverageMeter)
95
+ write_tb = tb_writer is not None and epoch_i is not None
96
+
97
+ mr_res = []
98
+ for batch in tqdm(eval_loader, desc="compute st ed scores"):
99
+ query_meta = batch[0]
100
+ model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
101
+ outputs = model(**model_inputs)
102
+ prob = outputs["pred_logits"] # the last channel may be 1 or 2.
103
+ # if opt.eval_mode == 'v1':
104
+ # prob = prob * outputs["saliency_scores"].unsqueeze(-1) # v1
105
+ # if opt.eval_mode == 'v2':
106
+ # prob = F.softmax(prob, dim=1) * outputs["saliency_scores"].unsqueeze(-1) # v2
107
+ # if opt.eval_mode == 'v3':
108
+ # prob = outputs["saliency_scores"].unsqueeze(-1)
109
+ if outputs["pred_logits"].shape[-1] > 1:
110
+ prob = F.softmax(outputs["pred_logits"], -1) # (batch_size, #queries, #classes=2)
111
+ if opt.span_loss_type == "l1":
112
+ scores = prob[..., 0] # * (batch_size, #queries) foreground label is 0, we directly take it
113
+ pred_spans = outputs["pred_spans"] # (bsz, #queries, 2)
114
+
115
+ if opt.model_id not in ['moment_detr']: # dense regression.
116
+ start_spans = targets['timestamp']
117
+ pred_spans = start_spans + pred_spans
118
+ mask = targets['timestamp_mask'].bool()
119
+ scores[~mask] = 0
120
+ # if opt.eval_mode == 'v4':
121
+ # _mask = targets['timestamp_window'].bool()
122
+ # scores[~_mask] = 0
123
+
124
+ if opt.eval_mode == 'add':
125
+ # pdb.set_trace()
126
+ _saliency_scores = outputs["saliency_scores"].half() + prob.squeeze(-1)
127
+ else:
128
+ _saliency_scores = outputs["saliency_scores"].half() # (bsz, L)
129
+
130
+ if opt.eval_mode == 'add_mr':
131
+ prob = outputs["saliency_scores"].half().unsqueeze(-1) + prob
132
+
133
+ saliency_scores = []
134
+ valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist()
135
+ for j in range(len(valid_vid_lengths)):
136
+ saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist())
137
+ else:
138
+ bsz, n_queries = outputs["pred_spans"].shape[:2] # # (bsz, #queries, max_v_l *2)
139
+ pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l)
140
+ # TODO use more advanced decoding method with st_ed product
141
+ pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1) # 2 * (bsz, #queries, 2)
142
+ scores = torch.prod(pred_span_scores, 2) # (bsz, #queries)
143
+ pred_spans[:, 1] += 1
144
+ pred_spans *= opt.clip_length
145
+
146
+ # compose predictions
147
+ for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())):
148
+ if opt.span_loss_type == "l1":
149
+ if opt.model_id in ['moment_detr']:
150
+ spans = span_cxw_to_xx(spans) * meta["duration"]
151
+ else:
152
+ spans = spans * meta["duration"]
153
+ spans = torch.clamp(spans, 0, meta["duration"]) # added by Kevin, since window cannot be longer than video duration.
154
+
155
+ # (#queries, 3), [st(float), ed(float), score(float)]
156
+ cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist()
157
+ if not opt.no_sort_results:
158
+ cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True)
159
+ cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds]
160
+ cur_query_pred = dict(
161
+ qid=meta["qid"],
162
+ query=meta["query"],
163
+ vid=meta["vid"],
164
+ pred_relevant_windows=cur_ranked_preds,
165
+ pred_saliency_scores=saliency_scores[idx]
166
+ )
167
+ mr_res.append(cur_query_pred)
168
+
169
+ if criterion:
170
+ loss_dict = criterion(outputs, targets)
171
+ weight_dict = criterion.weight_dict
172
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
173
+ loss_dict["loss_overall"] = float(losses) # for logging only
174
+ for k, v in loss_dict.items():
175
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
176
+
177
+ if opt.debug:
178
+ break
179
+
180
+ if write_tb and criterion:
181
+ for k, v in loss_meters.items():
182
+ tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1)
183
+
184
+ post_processor = PostProcessorDETR(
185
+ clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150,
186
+ min_w_l=2, max_w_l=150, move_window_method="left",
187
+ # process_func_names=("clip_ts", "round_multiple")
188
+ process_func_names=["round_multiple"] # have added `clamp' op on line 147, thus we do not need `clip_ts' again;
189
+ )
190
+ # todo: are we need round_multiple?
191
+ if opt.round_multiple > 0:
192
+ mr_res = post_processor(mr_res)
193
+ return mr_res, loss_meters
194
+
195
+ def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer):
196
+ """compute and save query and video proposal embeddings"""
197
+ eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) # list(dict)
198
+ return eval_res, eval_loss_meters
199
+
200
+ def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None):
201
+ logger.info("Generate submissions")
202
+ model.eval()
203
+ if criterion is not None and eval_dataset.load_labels:
204
+ criterion.eval()
205
+ else:
206
+ criterion = None
207
+
208
+ eval_loader = DataLoader(
209
+ eval_dataset,
210
+ collate_fn=start_end_collate_mr,
211
+ batch_size=opt.eval_bsz,
212
+ num_workers=opt.num_workers,
213
+ shuffle=False,
214
+ pin_memory=opt.pin_memory
215
+ )
216
+
217
+ submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer)
218
+ if opt.no_sort_results:
219
+ save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl")
220
+ metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing(
221
+ submission, opt, eval_dataset.data, save_submission_filename)
222
+ return metrics, metrics_nms, eval_loss_meters, latest_file_paths
223
+
224
+ def start_inference():
225
+ logger.info("Setup config, data and model...")
226
+ opt = TestOptions().parse()
227
+ # pdb.set_trace()
228
+ cudnn.benchmark = True
229
+ cudnn.deterministic = False
230
+
231
+ assert opt.eval_path is not None
232
+ eval_dataset = DatasetMR(
233
+ dset_name=opt.dset_name,
234
+ data_path=opt.eval_path,
235
+ v_feat_dirs=opt.v_feat_dirs,
236
+ q_feat_dir=opt.t_feat_dir,
237
+ v_feat_dim=opt.v_feat_dim,
238
+ q_feat_dim=opt.t_feat_dim,
239
+ q_feat_type="last_hidden_state",
240
+ max_q_l=opt.max_q_l,
241
+ max_v_l=opt.max_v_l,
242
+ ctx_mode=opt.ctx_mode,
243
+ data_ratio=opt.data_ratio,
244
+ normalize_v=not opt.no_norm_vfeat,
245
+ normalize_t=not opt.no_norm_tfeat,
246
+ clip_len=opt.clip_length,
247
+ max_windows=opt.max_windows,
248
+ load_labels=True, # opt.eval_split_name == "val",
249
+ span_loss_type=opt.span_loss_type,
250
+ txt_drop_ratio=0,
251
+ use_cache=opt.use_cache,
252
+ )
253
+
254
+ if opt.lr_warmup > 0:
255
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
256
+ total_steps = opt.n_epoch
257
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
258
+ opt.lr_warmup = [warmup_steps, total_steps]
259
+
260
+ model, criterion, _, _ = setup_model(opt)
261
+ save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format(
262
+ opt.dset_name, opt.eval_split_name, opt.eval_id)
263
+ logger.info("Starting inference...")
264
+ with torch.no_grad():
265
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
266
+ eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion)
267
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
268
+ if metrics_nms is not None:
269
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
270
+
271
+
272
+ if __name__ == '__main__':
273
+ start_inference()
main/inference_qfvs.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import time
4
+ import json
5
+ import pprint
6
+ import random
7
+ import importlib
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+ from collections import defaultdict
11
+
12
+ import h5py
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.backends.cudnn as cudnn
16
+ from torch.utils.data import DataLoader
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ import sys
20
+ sys.path.append('/Users/kevin/univtg')
21
+ from main.config import BaseOptions, setup_model
22
+ from main.dataset_qfvs import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
23
+ from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle, l2_normalize_np_array
24
+ from utils.model_utils import count_parameters
25
+ from eval.qfvs import calculate_semantic_matching, load_videos_tag
26
+
27
+ import logging
28
+ logger = logging.getLogger(__name__)
29
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
30
+ datefmt="%Y-%m-%d %H:%M:%S",
31
+ level=logging.INFO)
32
+
33
+ def eval_epoch(model, config, opt):
34
+ model.eval()
35
+ f1_sum = 0; p_sum = 0; r_sum = 0
36
+
37
+ assert len(config['test_videos']) == 1
38
+ video_id = config['test_videos'][0]
39
+ embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")
40
+
41
+ feat_type = config['vid_feature']
42
+ feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
43
+ features = torch.from_numpy(feat['features'][()])
44
+ seg_len = torch.from_numpy(feat['seg_len'][()])
45
+ # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()
46
+
47
+ # dim = features.shape[-1]
48
+ # ctx_l = seg_len.sum().cpu()
49
+
50
+ # dim = features.shape[-1]
51
+ # ctx_l = features.shape[1]
52
+ # seg_len = torch.ones(ctx_l)
53
+ # features = features.reshape(-1, dim)[:ctx_l]
54
+
55
+ # tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
56
+ # tef_ed = tef_st + 1.0 / ctx_l
57
+ # tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2)
58
+ # features = torch.cat([features, tef], dim=1) # (Lv, Dv+2)
59
+
60
+ transfer = {"Cupglass": "Glass",
61
+ "Musicalinstrument": "Instrument",
62
+ "Petsanimal": "Animal"}
63
+
64
+ with open(os.path.join('./plot', opt.dset_name, str(opt.qfvs_split) +'.jsonl'), 'w') as f_write:
65
+ for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
66
+ evaluation_num=len(files)
67
+
68
+ mask_GT = torch.zeros(config["max_segment_num"], config["max_frame_num"], dtype=torch.bool).cuda()
69
+ for j in range(len(seg_len)):
70
+ for k in range(seg_len[j]):
71
+ mask_GT[j][k] = 1
72
+
73
+ for file in files:
74
+ summaries_GT=[]
75
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f:
76
+ for line in f.readlines():
77
+ summaries_GT.append(int(line.strip()))
78
+
79
+ concept1, concept2 = file.split('_')[0:2]
80
+
81
+ ##############
82
+ if concept1 in transfer:
83
+ concept1 = transfer[concept1]
84
+ if concept2 in transfer:
85
+ concept2 = transfer[concept2]
86
+ concept1 = embedding[concept1]
87
+ concept2 = embedding[concept2]
88
+
89
+ concept1 = l2_normalize_np_array(concept1)
90
+ concept2 = l2_normalize_np_array(concept2)
91
+
92
+ data = {
93
+ 'features':features,
94
+ 'seg_len': seg_len,
95
+ 'tokens_pad1':torch.from_numpy(concept1),
96
+ 'tokens_pad2':torch.from_numpy(concept2),
97
+ 'mask_GT': mask_GT
98
+ }
99
+
100
+ input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)
101
+
102
+ summaries_GT = [x - 1 for x in summaries_GT]
103
+ video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")
104
+
105
+ if opt.f_loss_coef == 0:
106
+ output_type = 'saliency_scores'
107
+ elif opt.s_loss_intra_coef == 0:
108
+ output_type = 'pred_logits'
109
+ else:
110
+ if config['qfvs_score_ensemble'] > 0:
111
+ output_type = ['pred_logits', 'saliency_scores']
112
+ else:
113
+ output_type = 'pred_logits'
114
+
115
+ with torch.no_grad():
116
+ if not isinstance(output_type, list):
117
+ score1 = model(**input1)[output_type].squeeze()
118
+ score1 = score1.masked_select(mask_GT)
119
+
120
+ score2 = model(**input2)[output_type].squeeze()
121
+ score2 = score2.masked_select(mask_GT)
122
+
123
+ score = model(**input_oracle)[output_type].squeeze()
124
+ score = score.masked_select(mask_GT)
125
+ else:
126
+ score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
127
+ for output_t in output_type:
128
+ score1 += model(**input1)[output_t].squeeze().masked_select(mask_GT)
129
+ score2 += model(**input2)[output_t].squeeze().masked_select(mask_GT)
130
+ score += model(**input_oracle)[output_t].squeeze().masked_select(mask_GT)
131
+
132
+ if config['qfvs_score_gather'] > 0:
133
+ score = score + score1 + score2
134
+ else:
135
+ score = score
136
+
137
+ # since video4 features dim is greater than video_shots_tag.
138
+ score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
139
+ _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))
140
+
141
+ c1, c2 = file.split('_')[0:2]
142
+ if c1 in transfer:
143
+ c1 = transfer[c1]
144
+ if c2 in transfer:
145
+ c2 = transfer[c2]
146
+
147
+ p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
148
+ entry = {'concept1': c1, 'concept2': c2,
149
+ 'score':score.tolist(),
150
+ 'top_percent': config["top_percent"],
151
+ 'top_pred':top_index.tolist(),
152
+ 'gt':summaries_GT,
153
+ 'p': p, 'r': r, 'f1': f1,
154
+ 'shots': video_shots_tag[video_id-1].shape[0]}
155
+ f_write.write(json.dumps(entry) + '\n')
156
+ f1_sum+=f1; r_sum+=r; p_sum+=p
157
+ return {'F': round(100* f1_sum/evaluation_num,2) ,
158
+ 'R': round(100* r_sum/evaluation_num,2) ,
159
+ 'P': round(100* p_sum/evaluation_num,2) }
160
+
161
+ def idx2time(idx):
162
+ sec1, sec2 = idx*5, (idx+1)*5
163
+
164
+ h1 = sec1 // 3600
165
+ m1 = (sec1 - h1*3600) // 60
166
+ s1 = sec1 % 60
167
+
168
+ h2 = sec2 // 3600
169
+ m2 = (sec2 - h2*3600) // 60
170
+ s2 = sec2 % 60
171
+ print(h1,m1,s1,'\t', h2,m2,s2)
172
+
173
+ def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
174
+ model.train()
175
+ criterion.train()
176
+
177
+ # init meters
178
+ time_meters = defaultdict(AverageMeter)
179
+ loss_meters = defaultdict(AverageMeter)
180
+
181
+ timer_dataloading = time.time()
182
+ loss_total = 0
183
+
184
+ for batch_idx, batch in enumerate(tqdm(train_loader)):
185
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
186
+ timer_start = time.time()
187
+ model_input1, model_input2, model_input_oracle, \
188
+ model_gt1, model_gt2, model_gt_oracle, \
189
+ mask_GT = prepare_batch_inputs_qfvs(batch, config)
190
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
191
+
192
+ timer_start = time.time()
193
+ output1 = model(**model_input1)
194
+ output2 = model(**model_input2)
195
+ output_oracle = model(**model_input_oracle)
196
+
197
+ loss_dict = {}
198
+ loss_dict1 = criterion(output1, model_gt1, mask_GT)
199
+ loss_dict2 = criterion(output2, model_gt2, mask_GT)
200
+ loss_dict3 = criterion(output_oracle, model_gt_oracle, mask_GT)
201
+
202
+ weight_dict = criterion.weight_dict
203
+ if config['qfvs_loss_gather'] > 0:
204
+ for k in loss_dict1.keys():
205
+ loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]
206
+ else:
207
+ loss_dict = loss_dict3
208
+
209
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
210
+ loss_total += losses.item()
211
+
212
+ time_meters["model_forward_time"].update(time.time() - timer_start)
213
+ timer_start = time.time()
214
+ optimizer.zero_grad()
215
+ losses.backward()
216
+ if opt.grad_clip > 0:
217
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
218
+ optimizer.step()
219
+ time_meters["model_backward_time"].update(time.time() - timer_start)
220
+
221
+ timer_dataloading = time.time()
222
+ return round(loss_total / len(train_loader), 2)
223
+
224
+ # train in single domain.
225
+ def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
226
+ # if opt.device.type == "cuda":
227
+ # logger.info("CUDA enabled.")
228
+ # model.to(opt.device)
229
+
230
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
231
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
232
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
233
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
234
+
235
+ prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0}
236
+ if opt.start_epoch is None:
237
+ start_epoch = -1 if opt.eval_init else 0
238
+ else:
239
+ start_epoch = opt.start_epoch
240
+
241
+ val_score = eval_epoch(model, config, opt)
242
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
243
+ logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
244
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
245
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
246
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
247
+ if epoch_i > -1:
248
+ loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
249
+ lr_scheduler.step()
250
+ eval_epoch_interval = opt.eval_epoch
251
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
252
+ with torch.no_grad():
253
+ val_score = eval_epoch(model, config, opt)
254
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
255
+ logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
256
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
257
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
258
+
259
+ if prev_best_score['Fscore'] < val_score['F']:
260
+ prev_best_score['Fscore'] = val_score['F']
261
+ prev_best_score['Precision'] = val_score['P']
262
+ prev_best_score['Recall'] = val_score['R']
263
+
264
+ checkpoint = {
265
+ "model": model.state_dict(),
266
+ "optimizer": optimizer.state_dict(),
267
+ "epoch": epoch_i,
268
+ "opt": opt
269
+ }
270
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
271
+ tb_writer.close()
272
+ return prev_best_score
273
+
274
+ def update_config(opt, config):
275
+ # for key in ["max_segment_num", "max_frame_num", "top_percent",
276
+ # "qfvs_vid_feature", "qfvs_txt_feature", "qfvs_dense_shot",
277
+ # "qfvs_score_ensemble", "qfvs_score_gather", "qfvs_loss_gather"]:
278
+ config["max_segment_num"] = opt.max_segment_num
279
+ config["max_frame_num"] = opt.max_frame_num
280
+ config["top_percent"] = opt.top_percent
281
+ config["vid_feature"] = opt.qfvs_vid_feature
282
+ config["txt_feature"] = opt.qfvs_txt_feature
283
+ config["qfvs_dense_shot"] = opt.qfvs_dense_shot
284
+ config["qfvs_score_ensemble"] = opt.qfvs_score_ensemble
285
+ config["qfvs_score_gather"] = opt.qfvs_score_gather
286
+ config["qfvs_loss_gather"] = opt.qfvs_loss_gather
287
+ return config
288
+
289
+ def start_training():
290
+ logger.info("Setup config, data and model...")
291
+ opt = BaseOptions().parse()
292
+ set_seed(opt.seed)
293
+
294
+ # config = load_json("./main/config_qfvs.json")
295
+ config = {}
296
+ config = update_config(opt, config)
297
+
298
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
299
+
300
+ # key -> test video; value -> training videos.
301
+ qfvs_split = {
302
+ 1: [2, 3, 4],
303
+ 2: [1, 3, 4],
304
+ 3: [1, 2, 4],
305
+ 4: [1, 2, 3]
306
+ }
307
+
308
+ scores_videos = {}
309
+ for test_id, splits in qfvs_split.items():
310
+ if opt.qfvs_split != -1:
311
+ if test_id != opt.qfvs_split:
312
+ continue
313
+ logger.info(f"Start Training {opt.dset_name}: {test_id}")
314
+ config['train_videos'] = qfvs_split[test_id]
315
+ config['test_videos'] = [test_id]
316
+ train_dataset = DatasetQFVS(config)
317
+ train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)
318
+
319
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
320
+ count_parameters(model)
321
+ best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
322
+ scores_videos['V'+str(test_id)] = best_score
323
+
324
+ # save the final results.
325
+ avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
326
+ avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
327
+ avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
328
+ scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall}
329
+
330
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
331
+ save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)
332
+
333
+ tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
334
+ tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
335
+ tb_writer.close()
336
+
337
+ print(scores_videos)
338
+ return
339
+
340
+ if __name__ == '__main__':
341
+ start_training()
342
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
main/train_hl.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import time
4
+ import json
5
+ import pprint
6
+ import random
7
+ import importlib
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.backends.cudnn as cudnn
15
+ from torch.utils.data import DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ import sys
19
+ sys.path.append('/data/home/qinghonglin/univtg')
20
+ from main.config import BaseOptions, setup_model
21
+ from main.dataset import DatasetHL, prepare_batch_inputs_hl, start_end_collate_hl
22
+ from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl
23
+ from utils.model_utils import count_parameters
24
+
25
+ import logging
26
+ logger = logging.getLogger(__name__)
27
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
28
+ datefmt="%Y-%m-%d %H:%M:%S",
29
+ level=logging.INFO)
30
+
31
+ def eval_epoch(model, train_val_dataset, opt): #, nms_thresh, device):
32
+ model.eval()
33
+
34
+ scores = []
35
+ train_val_dataset.set_state('val')
36
+ val_loader = DataLoader(
37
+ train_val_dataset,
38
+ collate_fn=start_end_collate_hl,
39
+ batch_size=opt.eval_bsz,
40
+ num_workers=opt.num_workers,
41
+ shuffle=False,
42
+ pin_memory=opt.pin_memory
43
+ )
44
+
45
+ with torch.no_grad():
46
+ for data in val_loader:
47
+ model_inputs, targets = prepare_batch_inputs_hl(data)
48
+ outputs = model(**model_inputs)
49
+ # pred_cls = outputs['pred_logits'].squeeze(-1)
50
+ # pred_cls = outputs['saliency_scores']
51
+ # pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
52
+
53
+ # pdb.set_trace()
54
+ if opt.f_loss_coef == 0:
55
+ pred_cls = outputs['saliency_scores']
56
+ elif opt.s_loss_intra_coef == 0:
57
+ pred_cls = outputs['pred_logits'].squeeze(-1)
58
+ else:
59
+ if opt.eval_mode == 'add':
60
+ pred_cls = outputs['saliency_scores'] + outputs['pred_logits'].squeeze(-1)
61
+ else:
62
+ pred_cls = outputs['pred_logits'].squeeze(-1)
63
+
64
+ pred_cls = pred_cls.detach().cpu()
65
+ scores.append(pred_cls)
66
+ map = round(train_val_dataset.evaluate(scores)['mAP'] * 100, 4)
67
+ return map
68
+
69
+ def train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer):
70
+ logger.info(f"[Epoch {epoch_i+1}]")
71
+ model.train()
72
+ criterion.train()
73
+
74
+ train_val_dataset.set_state('train')
75
+ train_loader = DataLoader(
76
+ train_val_dataset,
77
+ collate_fn=start_end_collate_hl,
78
+ batch_size=opt.bsz,
79
+ num_workers=opt.num_workers,
80
+ shuffle=True,
81
+ pin_memory=opt.pin_memory
82
+ )
83
+
84
+ # init meters
85
+ time_meters = defaultdict(AverageMeter)
86
+ loss_meters = defaultdict(AverageMeter)
87
+
88
+ num_training_examples = len(train_loader)
89
+ timer_dataloading = time.time()
90
+ for batch_idx, batch in enumerate(train_loader):
91
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
92
+ timer_start = time.time()
93
+ model_inputs, targets = prepare_batch_inputs_hl(batch)
94
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
95
+
96
+ timer_start = time.time()
97
+ outputs = model(**model_inputs)
98
+ loss_dict = criterion(outputs, targets)
99
+ weight_dict = criterion.weight_dict
100
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
101
+ time_meters["model_forward_time"].update(time.time() - timer_start)
102
+
103
+ timer_start = time.time()
104
+ optimizer.zero_grad()
105
+ losses.backward()
106
+ if opt.grad_clip > 0:
107
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
108
+ optimizer.step()
109
+ time_meters["model_backward_time"].update(time.time() - timer_start)
110
+
111
+ loss_dict["loss_overall"] = float(losses)
112
+ for k, v in loss_dict.items():
113
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
114
+
115
+ timer_dataloading = time.time()
116
+ if opt.debug and batch_idx == 3:
117
+ break
118
+
119
+ # print/add logs
120
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
121
+ for k, v in loss_meters.items():
122
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
123
+
124
+ to_write = opt.train_log_txt_formatter.format(
125
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
126
+ epoch=epoch_i+1,
127
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
128
+ with open(opt.train_log_filepath, "a") as f:
129
+ f.write(to_write)
130
+
131
+ logger.info("Epoch time stats:")
132
+ for name, meter in time_meters.items():
133
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
134
+ logger.info(f"{name} ==> {d}")
135
+
136
+ # train in single domain.
137
+ def train(model, criterion, optimizer, lr_scheduler, train_val_dataset, opt):
138
+ # if opt.device.type == "cuda":
139
+ # logger.info("CUDA enabled.")
140
+ # model.to(opt.device)
141
+
142
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
143
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
144
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
145
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
146
+
147
+ prev_best_score = 0.
148
+ if opt.start_epoch is None:
149
+ start_epoch = -1 if opt.eval_init else 0
150
+ else:
151
+ start_epoch = opt.start_epoch
152
+
153
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
154
+ if epoch_i > -1:
155
+ train_epoch(model, criterion, train_val_dataset, optimizer, opt, epoch_i, tb_writer)
156
+ lr_scheduler.step()
157
+ eval_epoch_interval = opt.eval_epoch
158
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
159
+ with torch.no_grad():
160
+ scores = eval_epoch(model, train_val_dataset, opt)
161
+ tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-{train_val_dataset.domain}-mAP", float(scores), epoch_i+1)
162
+ if prev_best_score < scores:
163
+ prev_best_score = scores
164
+ checkpoint = {
165
+ "model": model.state_dict(),
166
+ "optimizer": optimizer.state_dict(),
167
+ "epoch": epoch_i,
168
+ "opt": opt
169
+ }
170
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_{train_val_dataset.domain}_best.ckpt"))
171
+ tb_writer.close()
172
+ return prev_best_score
173
+
174
+ def start_training():
175
+ logger.info("Setup config, data and model...")
176
+ opt = BaseOptions().parse()
177
+ set_seed(opt.seed)
178
+
179
+ from main.config_hl import TVSUM_SPLITS, YOUTUBE_SPLITS
180
+ if opt.dset_name == "tvsum":
181
+ domain_splits = TVSUM_SPLITS.keys()
182
+ if opt.dset_name == "youtube":
183
+ domain_splits = YOUTUBE_SPLITS.keys()
184
+
185
+ scores = {}
186
+ if opt.lr_warmup > 0:
187
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
188
+ total_steps = opt.n_epoch
189
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
190
+ opt.lr_warmup = [warmup_steps, total_steps]
191
+
192
+ domain_splits = domain_splits if not opt.domain_name else [opt.domain_name]
193
+
194
+ for domain in domain_splits:
195
+ dataset_config = dict(
196
+ dset_name=opt.dset_name,
197
+ domain=domain,
198
+ data_path=opt.train_path,
199
+ v_feat_types=opt.v_feat_types,
200
+ v_feat_dirs=opt.v_feat_dirs,
201
+ t_feat_dir=opt.t_feat_dir,
202
+ use_tef=True
203
+ )
204
+ dataloader = DatasetHL(**dataset_config)
205
+
206
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
207
+ count_parameters(model)
208
+ logger.info(f"Start Training {domain}")
209
+ best_score = train(model, criterion, optimizer, lr_scheduler, dataloader, opt)
210
+ scores[domain] = best_score
211
+ scores['AVG'] = sum(scores.values()) / len(scores)
212
+
213
+ # save the final results.
214
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
215
+ save_json(scores, save_metrics_path, save_pretty=True, sort_keys=False)
216
+
217
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
218
+ tb_writer.add_text(f"HL-{opt.dset_name}", dict_to_markdown(scores, max_str_len=None))
219
+ tb_writer.add_scalar(f"Eval/HL-{opt.dset_name}-avg-mAP-key", float(scores['AVG']), 1)
220
+ tb_writer.close()
221
+ # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
222
+
223
+ print(opt.dset_name)
224
+ print(scores)
225
+ return
226
+
227
+ if __name__ == '__main__':
228
+ start_training()
229
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
main/train_mr.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import sys
4
+ import time
5
+ import json
6
+ import pprint
7
+ import random
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.backends.cudnn as cudnn
15
+ from torch.utils.data import DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ sys.path.append('/data/home/qinghonglin/univtg')
19
+ from main.config import BaseOptions, setup_model
20
+ from main.dataset import \
21
+ DatasetMR, start_end_collate_mr, prepare_batch_inputs_mr
22
+ from main.inference_mr import eval_epoch, start_inference
23
+ from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
24
+ from utils.model_utils import count_parameters
25
+
26
+ import logging
27
+ logger = logging.getLogger(__name__)
28
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
29
+ datefmt="%Y-%m-%d %H:%M:%S",
30
+ level=logging.INFO)
31
+
32
+ def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
33
+ logger.info(f"[Epoch {epoch_i+1}]")
34
+ model.train()
35
+ criterion.train()
36
+
37
+ # init meters
38
+ time_meters = defaultdict(AverageMeter)
39
+ loss_meters = defaultdict(AverageMeter)
40
+
41
+ num_training_examples = len(train_loader)
42
+ timer_dataloading = time.time()
43
+ for batch_idx, batch in tqdm(enumerate(train_loader),
44
+ desc="Training Iteration",
45
+ total=num_training_examples):
46
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
47
+
48
+ timer_start = time.time()
49
+ model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
50
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
51
+
52
+ timer_start = time.time()
53
+
54
+ # try:
55
+ outputs = model(**model_inputs)
56
+ loss_dict = criterion(outputs, targets)
57
+ weight_dict = criterion.weight_dict
58
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
59
+ time_meters["model_forward_time"].update(time.time() - timer_start)
60
+
61
+ timer_start = time.time()
62
+ optimizer.zero_grad()
63
+ losses.backward()
64
+
65
+ if opt.grad_clip > 0:
66
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
67
+ optimizer.step()
68
+ time_meters["model_backward_time"].update(time.time() - timer_start)
69
+
70
+ loss_dict["loss_overall"] = float(losses) # for logging only
71
+ for k, v in loss_dict.items():
72
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
73
+
74
+ timer_dataloading = time.time()
75
+
76
+ # print/add logs
77
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
78
+ for k, v in loss_meters.items():
79
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
80
+
81
+ to_write = opt.train_log_txt_formatter.format(
82
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
83
+ epoch=epoch_i+1,
84
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
85
+ with open(opt.train_log_filepath, "a") as f:
86
+ f.write(to_write)
87
+
88
+ logger.info("Epoch time stats:")
89
+ for name, meter in time_meters.items():
90
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
91
+ logger.info(f"{name} ==> {d}")
92
+
93
+
94
+ def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
95
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
96
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
97
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
98
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
99
+
100
+ train_loader = DataLoader(
101
+ train_dataset,
102
+ collate_fn=start_end_collate_mr,
103
+ batch_size=opt.bsz,
104
+ num_workers=opt.num_workers,
105
+ shuffle=True,
106
+ pin_memory=opt.pin_memory
107
+ )
108
+
109
+ prev_best_score = 0.
110
+ es_cnt = 0
111
+ if opt.start_epoch is None:
112
+ start_epoch = -1 if opt.eval_init else 0
113
+ else:
114
+ start_epoch = opt.start_epoch
115
+ save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
116
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
117
+ if epoch_i > -1:
118
+ train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
119
+ lr_scheduler.step()
120
+ eval_epoch_interval = opt.eval_epoch
121
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
122
+ with torch.no_grad():
123
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
124
+ eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
125
+
126
+ # log
127
+ to_write = opt.eval_log_txt_formatter.format(
128
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
129
+ epoch=epoch_i,
130
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
131
+ eval_metrics_str=json.dumps(metrics_no_nms))
132
+
133
+ with open(opt.eval_log_filepath, "a") as f:
134
+ f.write(to_write)
135
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
136
+ if metrics_nms is not None:
137
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
138
+
139
+ metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
140
+ for k, v in metrics["brief"].items():
141
+ tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
142
+
143
+ # stop_score = metrics["brief"]["MR-full-mAP"]
144
+ # pdb.set_trace()
145
+ stop_score = metrics["brief"][opt.main_metric]
146
+ if stop_score > prev_best_score:
147
+ es_cnt = 0
148
+ prev_best_score = stop_score
149
+
150
+ checkpoint = {
151
+ "model": model.state_dict(),
152
+ "optimizer": optimizer.state_dict(),
153
+ "lr_scheduler": lr_scheduler.state_dict(),
154
+ "epoch": epoch_i,
155
+ "opt": opt
156
+ }
157
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
158
+
159
+ best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
160
+ for src, tgt in zip(latest_file_paths, best_file_paths):
161
+ os.renames(src, tgt)
162
+ logger.info("The checkpoint file has been updated.")
163
+ else:
164
+ es_cnt += 1
165
+ if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
166
+ with open(opt.train_log_filepath, "a") as f:
167
+ f.write(f"Early Stop at epoch {epoch_i}")
168
+ logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
169
+ break
170
+
171
+ # save ckpt
172
+ checkpoint = {
173
+ "model": model.state_dict(),
174
+ "optimizer": optimizer.state_dict(),
175
+ "lr_scheduler": lr_scheduler.state_dict(),
176
+ "epoch": epoch_i,
177
+ "opt": opt
178
+ }
179
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
180
+
181
+ if (epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies
182
+ checkpoint = {
183
+ "model": model.state_dict(),
184
+ "optimizer": optimizer.state_dict(),
185
+ "epoch": epoch_i,
186
+ "opt": opt
187
+ }
188
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
189
+
190
+ if opt.debug:
191
+ break
192
+
193
+ tb_writer.close()
194
+
195
+
196
+ def start_training():
197
+ logger.info("Setup config, data and model...")
198
+ opt = BaseOptions().parse()
199
+ set_seed(opt.seed)
200
+ if opt.debug: # keep the model run deterministically
201
+ # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
202
+ # Enable this only when input size is fixed.
203
+ cudnn.benchmark = False
204
+ cudnn.deterministic = True
205
+
206
+ dataset_config = dict(
207
+ dset_name=opt.dset_name,
208
+ data_path=opt.train_path,
209
+ v_feat_dirs=opt.v_feat_dirs,
210
+ q_feat_dir=opt.t_feat_dir,
211
+ v_feat_dim=opt.v_feat_dim,
212
+ q_feat_dim=opt.t_feat_dim,
213
+ q_feat_type="last_hidden_state",
214
+ max_q_l=opt.max_q_l,
215
+ max_v_l=opt.max_v_l,
216
+ ctx_mode=opt.ctx_mode,
217
+ data_ratio=opt.data_ratio,
218
+ normalize_v=not opt.no_norm_vfeat,
219
+ normalize_t=not opt.no_norm_tfeat,
220
+ clip_len=opt.clip_length,
221
+ max_windows=opt.max_windows,
222
+ span_loss_type=opt.span_loss_type,
223
+ txt_drop_ratio=opt.txt_drop_ratio,
224
+ use_cache=opt.use_cache,
225
+ add_easy_negative=opt.add_easy_negative,
226
+ easy_negative_only=opt.easy_negative_only
227
+ )
228
+
229
+ dataset_config["data_path"] = opt.train_path
230
+ train_dataset = DatasetMR(**dataset_config)
231
+
232
+ if opt.eval_path is not None:
233
+ dataset_config["data_path"] = opt.eval_path
234
+ dataset_config["txt_drop_ratio"] = 0
235
+ dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining
236
+ # dataset_config["load_labels"] = False # uncomment to calculate eval loss
237
+ eval_dataset = DatasetMR(**dataset_config)
238
+ else:
239
+ eval_dataset = None
240
+
241
+ if opt.lr_warmup > 0:
242
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
243
+ total_steps = opt.n_epoch
244
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
245
+ opt.lr_warmup = [warmup_steps, total_steps]
246
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
247
+ logger.info(f"Model {model}")
248
+ count_parameters(model)
249
+ logger.info("Start Training...")
250
+ train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
251
+ return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
252
+
253
+
254
+ if __name__ == '__main__':
255
+ best_ckpt_path, eval_split_name, eval_path, debug = start_training()
256
+ if not debug:
257
+ input_args = ["--resume", best_ckpt_path,
258
+ "--eval_split_name", eval_split_name,
259
+ "--eval_path", eval_path]
260
+
261
+ import sys
262
+ sys.argv[1:] = input_args
263
+ logger.info("\n\n\nFINISHED TRAINING!!!")
264
+ logger.info("Evaluating model at {}".format(best_ckpt_path))
265
+ logger.info("Input args {}".format(sys.argv[1:]))
266
+ start_inference()
main/train_qfvs.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import time
4
+ import json
5
+ import pprint
6
+ import random
7
+ import importlib
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+ from collections import defaultdict
11
+
12
+ import h5py
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.backends.cudnn as cudnn
16
+ from torch.utils.data import DataLoader
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ import sys
20
+ sys.path.append('/Users/kevin/univtg')
21
+ from main.config import BaseOptions, setup_model
22
+ from main.dataset_qfvs import DatasetQFVS, prepare_batch_inputs_qfvs, start_end_collate_qfvs
23
+ from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown, save_json, save_jsonl, load_json, load_pickle, l2_normalize_np_array
24
+ from utils.model_utils import count_parameters
25
+ from eval.qfvs import calculate_semantic_matching, load_videos_tag
26
+
27
+ import logging
28
+ logger = logging.getLogger(__name__)
29
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
30
+ datefmt="%Y-%m-%d %H:%M:%S",
31
+ level=logging.INFO)
32
+
33
+ def eval_epoch(model, config, opt):
34
+ model.eval()
35
+ f1_sum = 0; p_sum = 0; r_sum = 0
36
+
37
+ assert len(config['test_videos']) == 1
38
+ video_id = config['test_videos'][0]
39
+ embedding = load_pickle(f"./data/qfvs/txt_clip/{config['txt_feature']}.pkl")
40
+
41
+ feat_type = config['vid_feature']
42
+ feat = h5py.File(f'./data/qfvs/processed/P0{video_id}_{feat_type}.h5', 'r')
43
+ features = torch.from_numpy(feat['features'][()])
44
+ seg_len = torch.from_numpy(feat['seg_len'][()])
45
+ # seg_len = torch.tensor(feat['seg_len'][()]).unsqueeze(0).cuda()
46
+
47
+ # dim = features.shape[-1]
48
+ # ctx_l = seg_len.sum().cpu()
49
+
50
+ # dim = features.shape[-1]
51
+ # ctx_l = features.shape[1]
52
+ # seg_len = torch.ones(ctx_l)
53
+ # features = features.reshape(-1, dim)[:ctx_l]
54
+
55
+ # tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
56
+ # tef_ed = tef_st + 1.0 / ctx_l
57
+ # tef = torch.stack([tef_st, tef_ed], dim=1).cuda() # (Lv, 2)
58
+ # features = torch.cat([features, tef], dim=1) # (Lv, Dv+2)
59
+
60
+ transfer = {"Cupglass": "Glass",
61
+ "Musicalinstrument": "Instrument",
62
+ "Petsanimal": "Animal"}
63
+
64
+ for _,_,files in os.walk("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)):
65
+ evaluation_num=len(files)
66
+
67
+ mask_GT = torch.zeros(config["max_segment_num"], config["max_frame_num"], dtype=torch.bool).cuda()
68
+ for j in range(len(seg_len)):
69
+ for k in range(seg_len[j]):
70
+ mask_GT[j][k] = 1
71
+
72
+ for file in files:
73
+ summaries_GT=[]
74
+ with open("./data/qfvs/metadata/origin_data/Query-Focused_Summaries/Oracle_Summaries/P0"+str(video_id)+"/"+file,"r") as f:
75
+ for line in f.readlines():
76
+ summaries_GT.append(int(line.strip()))
77
+
78
+ concept1, concept2 = file.split('_')[0:2]
79
+
80
+ ##############
81
+ if concept1 in transfer:
82
+ concept1 = transfer[concept1]
83
+ if concept2 in transfer:
84
+ concept2 = transfer[concept2]
85
+ concept1 = embedding[concept1]
86
+ concept2 = embedding[concept2]
87
+
88
+ concept1 = l2_normalize_np_array(concept1)
89
+ concept2 = l2_normalize_np_array(concept2)
90
+
91
+ data = {
92
+ 'features':features,
93
+ 'seg_len': seg_len,
94
+ 'tokens_pad1':torch.from_numpy(concept1),
95
+ 'tokens_pad2':torch.from_numpy(concept2),
96
+ 'mask_GT': mask_GT
97
+ }
98
+
99
+ input1, input2, input_oracle, mask = prepare_batch_inputs_qfvs(start_end_collate_qfvs([data]), config, eval=True)
100
+
101
+ summaries_GT = [x - 1 for x in summaries_GT]
102
+ video_shots_tag = load_videos_tag(mat_path="./eval/Tags.mat")
103
+
104
+ if opt.f_loss_coef == 0:
105
+ output_type = 'saliency_scores'
106
+ elif opt.s_loss_intra_coef == 0:
107
+ output_type = 'pred_logits'
108
+ else:
109
+ if config['qfvs_score_ensemble'] > 0:
110
+ output_type = ['pred_logits', 'saliency_scores']
111
+ else:
112
+ output_type = 'pred_logits'
113
+
114
+ with torch.no_grad():
115
+ if not isinstance(output_type, list):
116
+ score1 = model(**input1)[output_type].squeeze()
117
+ score1 = score1.masked_select(mask_GT)
118
+
119
+ score2 = model(**input2)[output_type].squeeze()
120
+ score2 = score2.masked_select(mask_GT)
121
+
122
+ score = model(**input_oracle)[output_type].squeeze()
123
+ score = score.masked_select(mask_GT)
124
+ else:
125
+ score1, score2, score = torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda(), torch.zeros((int(mask.sum().item()))).cuda()
126
+ for output_t in output_type:
127
+ score1 += model(**input1)[output_t].squeeze().masked_select(mask_GT)
128
+ score2 += model(**input2)[output_t].squeeze().masked_select(mask_GT)
129
+ score += model(**input_oracle)[output_t].squeeze().masked_select(mask_GT)
130
+
131
+ if config['qfvs_score_gather'] > 0:
132
+ score = score + score1 + score2
133
+ else:
134
+ score = score
135
+
136
+ # since video4 features dim is greater than video_shots_tag.
137
+ score = score[:min(score.shape[0], video_shots_tag[video_id-1].shape[0])]
138
+ _, top_index = score.topk(int(score.shape[0] * config["top_percent"]))
139
+
140
+ p, r, f1 = calculate_semantic_matching(list(top_index.cpu().numpy()), summaries_GT, video_shots_tag, video_id=video_id-1)
141
+ f1_sum+=f1; r_sum+=r; p_sum+=p
142
+
143
+ return {'F': round(100* f1_sum/evaluation_num,2) ,
144
+ 'R': round(100* r_sum/evaluation_num,2) ,
145
+ 'P': round(100* p_sum/evaluation_num,2) }
146
+
147
+ def idx2time(idx):
148
+ sec1, sec2 = idx*5, (idx+1)*5
149
+
150
+ h1 = sec1 // 3600
151
+ m1 = (sec1 - h1*3600) // 60
152
+ s1 = sec1 % 60
153
+
154
+ h2 = sec2 // 3600
155
+ m2 = (sec2 - h2*3600) // 60
156
+ s2 = sec2 % 60
157
+ print(h1,m1,s1,'\t', h2,m2,s2)
158
+
159
+ def train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer):
160
+ model.train()
161
+ criterion.train()
162
+
163
+ # init meters
164
+ time_meters = defaultdict(AverageMeter)
165
+ loss_meters = defaultdict(AverageMeter)
166
+
167
+ timer_dataloading = time.time()
168
+ loss_total = 0
169
+
170
+ for batch_idx, batch in enumerate(tqdm(train_loader)):
171
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
172
+ timer_start = time.time()
173
+ model_input1, model_input2, model_input_oracle, \
174
+ model_gt1, model_gt2, model_gt_oracle, \
175
+ mask_GT = prepare_batch_inputs_qfvs(batch, config)
176
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
177
+
178
+ timer_start = time.time()
179
+ output1 = model(**model_input1)
180
+ output2 = model(**model_input2)
181
+ output_oracle = model(**model_input_oracle)
182
+
183
+ loss_dict = {}
184
+ loss_dict1 = criterion(output1, model_gt1, mask_GT)
185
+ loss_dict2 = criterion(output2, model_gt2, mask_GT)
186
+ loss_dict3 = criterion(output_oracle, model_gt_oracle, mask_GT)
187
+
188
+ weight_dict = criterion.weight_dict
189
+ if config['qfvs_loss_gather'] > 0:
190
+ for k in loss_dict1.keys():
191
+ loss_dict[k] = loss_dict1[k] + loss_dict2[k] + loss_dict3[k]
192
+ else:
193
+ loss_dict = loss_dict3
194
+
195
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
196
+ loss_total += losses.item()
197
+
198
+ time_meters["model_forward_time"].update(time.time() - timer_start)
199
+ timer_start = time.time()
200
+ optimizer.zero_grad()
201
+ losses.backward()
202
+ if opt.grad_clip > 0:
203
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
204
+ optimizer.step()
205
+ time_meters["model_backward_time"].update(time.time() - timer_start)
206
+
207
+ timer_dataloading = time.time()
208
+ return round(loss_total / len(train_loader), 2)
209
+
210
+ # train in single domain.
211
+ def train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config):
212
+ # if opt.device.type == "cuda":
213
+ # logger.info("CUDA enabled.")
214
+ # model.to(opt.device)
215
+
216
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
217
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
218
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
219
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
220
+
221
+ prev_best_score = {'Fscore':0, 'Precision':0, 'Recall':0}
222
+ if opt.start_epoch is None:
223
+ start_epoch = -1 if opt.eval_init else 0
224
+ else:
225
+ start_epoch = opt.start_epoch
226
+
227
+ val_score = eval_epoch(model, config, opt)
228
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), 0)
229
+ logger.info(f"[Epoch {0}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
230
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
231
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
232
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
233
+ if epoch_i > -1:
234
+ loss_epoch = train_epoch(model, criterion, train_loader, optimizer, opt, config, epoch_i, tb_writer)
235
+ lr_scheduler.step()
236
+ eval_epoch_interval = opt.eval_epoch
237
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
238
+ with torch.no_grad():
239
+ val_score = eval_epoch(model, config, opt)
240
+ tb_writer.add_scalar(f"Eval/QFVS-V{config['test_videos'][0]}-fscore", float(val_score['F']), epoch_i+1)
241
+ logger.info(f"[Epoch {epoch_i + 1}, Loss {loss_epoch}] [Fscore: {val_score['F']} / {prev_best_score['Fscore']}]"
242
+ f" [Precision: {val_score['P']} / {prev_best_score['Precision']}]"
243
+ f" [Recall: {val_score['R']} / {prev_best_score['Recall']}]")
244
+
245
+ if prev_best_score['Fscore'] < val_score['F']:
246
+ prev_best_score['Fscore'] = val_score['F']
247
+ prev_best_score['Precision'] = val_score['P']
248
+ prev_best_score['Recall'] = val_score['R']
249
+
250
+ checkpoint = {
251
+ "model": model.state_dict(),
252
+ "optimizer": optimizer.state_dict(),
253
+ "epoch": epoch_i,
254
+ "opt": opt
255
+ }
256
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_V{config['test_videos'][0]}_best.ckpt"))
257
+ tb_writer.close()
258
+ return prev_best_score
259
+
260
+ def update_config(opt, config):
261
+ # for key in ["max_segment_num", "max_frame_num", "top_percent",
262
+ # "qfvs_vid_feature", "qfvs_txt_feature", "qfvs_dense_shot",
263
+ # "qfvs_score_ensemble", "qfvs_score_gather", "qfvs_loss_gather"]:
264
+ config["max_segment_num"] = opt.max_segment_num
265
+ config["max_frame_num"] = opt.max_frame_num
266
+ config["top_percent"] = opt.top_percent
267
+ config["vid_feature"] = opt.qfvs_vid_feature
268
+ config["txt_feature"] = opt.qfvs_txt_feature
269
+ config["qfvs_dense_shot"] = opt.qfvs_dense_shot
270
+ config["qfvs_score_ensemble"] = opt.qfvs_score_ensemble
271
+ config["qfvs_score_gather"] = opt.qfvs_score_gather
272
+ config["qfvs_loss_gather"] = opt.qfvs_loss_gather
273
+ return config
274
+
275
+ def start_training():
276
+ logger.info("Setup config, data and model...")
277
+ opt = BaseOptions().parse()
278
+ set_seed(opt.seed)
279
+
280
+ # config = load_json("./main/config_qfvs.json")
281
+ config = {}
282
+ config = update_config(opt, config)
283
+
284
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
285
+
286
+ # key -> test video; value -> training videos.
287
+ qfvs_split = {
288
+ 1: [2, 3, 4],
289
+ 2: [1, 3, 4],
290
+ 3: [1, 2, 4],
291
+ 4: [1, 2, 3]
292
+ }
293
+
294
+ scores_videos = {}
295
+ for test_id, splits in qfvs_split.items():
296
+ logger.info(f"Start Training {opt.dset_name}: {test_id}")
297
+ config['train_videos'] = qfvs_split[test_id]
298
+ config['test_videos'] = [test_id]
299
+ train_dataset = DatasetQFVS(config)
300
+ train_loader = DataLoader(train_dataset, batch_size=opt.bsz, collate_fn=start_end_collate_qfvs, shuffle=True, num_workers=opt.num_workers)
301
+
302
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
303
+ count_parameters(model)
304
+ best_score = train(model, criterion, optimizer, lr_scheduler, train_loader, opt, config)
305
+ scores_videos['V'+str(test_id)] = best_score
306
+
307
+ # save the final results.
308
+ avg_fscore = sum([v['Fscore'] for k, v in scores_videos.items()]) / len(scores_videos)
309
+ avg_precision = sum([v['Precision'] for k, v in scores_videos.items()]) / len(scores_videos)
310
+ avg_recall = sum([v['Recall'] for k, v in scores_videos.items()]) / len(scores_videos)
311
+ scores_videos['avg'] = {'Fscore':avg_fscore, 'Precision':avg_precision, 'Recall':avg_recall}
312
+
313
+ save_metrics_path = os.path.join(opt.results_dir, f"best_{opt.dset_name}_{opt.eval_split_name}_preds_metrics.json")
314
+ save_json( scores_videos, save_metrics_path, save_pretty=True, sort_keys=False)
315
+
316
+ tb_writer.add_scalar(f"Eval/QFVS-avg-fscore", round(avg_fscore, 2), 1)
317
+ tb_writer.add_text(f"Eval/QFVS-{opt.dset_name}", dict_to_markdown(scores_videos, max_str_len=None))
318
+ tb_writer.close()
319
+
320
+ print(scores_videos)
321
+ return
322
+
323
+ if __name__ == '__main__':
324
+ start_training()
325
+ results = logger.info("\n\n\nFINISHED TRAINING!!!")
main/train_vlp.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import sys
4
+ import time
5
+ import json
6
+ import pprint
7
+ import random
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.backends.cudnn as cudnn
15
+ from torch.utils.data import DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ sys.path.append('/data/home/qinghonglin/univtg')
19
+ from main.config import BaseOptions, setup_model
20
+ from main.dataset import \
21
+ DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr
22
+ from main.inference_mr import eval_epoch, start_inference
23
+ from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
24
+ from utils.model_utils import count_parameters
25
+
26
+ import logging
27
+ logger = logging.getLogger(__name__)
28
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
29
+ datefmt="%Y-%m-%d %H:%M:%S",
30
+ level=logging.INFO)
31
+
32
+ def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer, cls=None):
33
+ logger.info(f"[Epoch {epoch_i+1}]")
34
+ model.train()
35
+ criterion.train()
36
+
37
+ # init meters
38
+ time_meters = defaultdict(AverageMeter)
39
+ loss_meters = defaultdict(AverageMeter)
40
+
41
+ num_training_examples = len(train_loader)
42
+ timer_dataloading = time.time()
43
+ for batch_idx, batch in tqdm(enumerate(train_loader),
44
+ desc="Training Iteration",
45
+ total=num_training_examples):
46
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
47
+
48
+ timer_start = time.time()
49
+ model_inputs, targets = prepare_batch_inputs_mr(batch[1], opt.device, non_blocking=opt.pin_memory)
50
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
51
+
52
+ timer_start = time.time()
53
+
54
+ if cls is not None:
55
+ model_inputs.update(cls)
56
+
57
+ # try:
58
+ outputs = model(**model_inputs)
59
+ loss_dict = criterion(outputs, targets)
60
+ weight_dict = criterion.weight_dict
61
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
62
+ time_meters["model_forward_time"].update(time.time() - timer_start)
63
+
64
+ timer_start = time.time()
65
+ optimizer.zero_grad()
66
+ losses.backward()
67
+ # except:
68
+ # pdb.set_trace()
69
+ if opt.grad_clip > 0:
70
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
71
+ optimizer.step()
72
+ time_meters["model_backward_time"].update(time.time() - timer_start)
73
+
74
+ loss_dict["loss_overall"] = float(losses) # for logging only
75
+ for k, v in loss_dict.items():
76
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
77
+
78
+ timer_dataloading = time.time()
79
+
80
+ # print/add logs
81
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
82
+ for k, v in loss_meters.items():
83
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
84
+
85
+ to_write = opt.train_log_txt_formatter.format(
86
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
87
+ epoch=epoch_i+1,
88
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
89
+ with open(opt.train_log_filepath, "a") as f:
90
+ f.write(to_write)
91
+
92
+ logger.info("Epoch time stats:")
93
+ for name, meter in time_meters.items():
94
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
95
+ logger.info(f"{name} ==> {d}")
96
+
97
+
98
+ def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
99
+ if opt.device.type == "cuda":
100
+ logger.info("CUDA enabled.")
101
+ model.to(opt.device)
102
+
103
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
104
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
105
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
106
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
107
+
108
+ train_loader = DataLoader(
109
+ train_dataset,
110
+ collate_fn=start_end_collate_mr,
111
+ batch_size=opt.bsz,
112
+ num_workers=opt.num_workers,
113
+ shuffle=True,
114
+ pin_memory=opt.pin_memory
115
+ )
116
+
117
+ if ('tal' in opt.train_path) or ('mq' in opt.train_path):
118
+ cls = {
119
+ 'src_cls': train_dataset.src_cls.cuda(),
120
+ 'src_cls_mask': train_dataset.src_cls_mask.cuda(),}
121
+ else:
122
+ cls = None
123
+
124
+ prev_best_score = 0.
125
+ es_cnt = 0
126
+ if opt.start_epoch is None:
127
+ start_epoch = -1 if opt.eval_init else 0
128
+ else:
129
+ start_epoch = opt.start_epoch
130
+ save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
131
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
132
+ if epoch_i > -1:
133
+ train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer, cls)
134
+ lr_scheduler.step()
135
+ eval_epoch_interval = opt.eval_epoch
136
+ if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
137
+ with torch.no_grad():
138
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
139
+ eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
140
+
141
+ # log
142
+ to_write = opt.eval_log_txt_formatter.format(
143
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
144
+ epoch=epoch_i,
145
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
146
+ eval_metrics_str=json.dumps(metrics_no_nms))
147
+
148
+ with open(opt.eval_log_filepath, "a") as f:
149
+ f.write(to_write)
150
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
151
+ if metrics_nms is not None:
152
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
153
+
154
+ metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
155
+ for k, v in metrics["brief"].items():
156
+ tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
157
+
158
+ # stop_score = metrics["brief"]["MR-full-mAP"]
159
+ # pdb.set_trace()
160
+ stop_score = metrics["brief"][opt.main_metric]
161
+ if stop_score > prev_best_score:
162
+ es_cnt = 0
163
+ prev_best_score = stop_score
164
+
165
+ checkpoint = {
166
+ "model": model.state_dict(),
167
+ "optimizer": optimizer.state_dict(),
168
+ "lr_scheduler": lr_scheduler.state_dict(),
169
+ "epoch": epoch_i,
170
+ "opt": opt
171
+ }
172
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
173
+
174
+ best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
175
+ for src, tgt in zip(latest_file_paths, best_file_paths):
176
+ os.renames(src, tgt)
177
+ logger.info("The checkpoint file has been updated.")
178
+ else:
179
+ es_cnt += 1
180
+ if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
181
+ with open(opt.train_log_filepath, "a") as f:
182
+ f.write(f"Early Stop at epoch {epoch_i}")
183
+ logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
184
+ break
185
+
186
+ # save ckpt
187
+ checkpoint = {
188
+ "model": model.state_dict(),
189
+ "optimizer": optimizer.state_dict(),
190
+ "lr_scheduler": lr_scheduler.state_dict(),
191
+ "epoch": epoch_i,
192
+ "opt": opt
193
+ }
194
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
195
+
196
+ if (epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies
197
+ checkpoint = {
198
+ "model": model.state_dict(),
199
+ "optimizer": optimizer.state_dict(),
200
+ "epoch": epoch_i,
201
+ "opt": opt
202
+ }
203
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
204
+
205
+ if opt.debug:
206
+ break
207
+
208
+ tb_writer.close()
209
+
210
+
211
+ def start_training():
212
+ logger.info("Setup config, data and model...")
213
+ opt = BaseOptions().parse()
214
+ set_seed(opt.seed)
215
+ if opt.debug: # keep the model run deterministically
216
+ # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
217
+ # Enable this only when input size is fixed.
218
+ cudnn.benchmark = False
219
+ cudnn.deterministic = True
220
+
221
+ dataset_config = dict(
222
+ dset_name=opt.dset_name,
223
+ data_path=opt.train_path,
224
+ v_feat_dirs=opt.v_feat_dirs,
225
+ q_feat_dir=opt.t_feat_dir,
226
+ v_feat_dim=opt.v_feat_dim,
227
+ q_feat_dim=opt.t_feat_dim,
228
+ q_feat_type="last_hidden_state",
229
+ max_q_l=opt.max_q_l,
230
+ max_v_l=opt.max_v_l,
231
+ ctx_mode=opt.ctx_mode,
232
+ data_ratio=opt.data_ratio,
233
+ normalize_v=not opt.no_norm_vfeat,
234
+ normalize_t=not opt.no_norm_tfeat,
235
+ clip_len=opt.clip_length,
236
+ max_windows=opt.max_windows,
237
+ span_loss_type=opt.span_loss_type,
238
+ txt_drop_ratio=opt.txt_drop_ratio,
239
+ use_cache=opt.use_cache,
240
+ add_easy_negative=opt.add_easy_negative,
241
+ easy_negative_only=opt.easy_negative_only
242
+ )
243
+
244
+ dataset_config["data_path"] = opt.train_path
245
+ train_dataset = DatasetVLP(**dataset_config)
246
+
247
+ if opt.eval_path is not None:
248
+ dataset_config["data_path"] = opt.eval_path
249
+ dataset_config["txt_drop_ratio"] = 0
250
+ dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining
251
+ # dataset_config["load_labels"] = False # uncomment to calculate eval loss
252
+ eval_dataset = DatasetVLP(**dataset_config)
253
+ else:
254
+ eval_dataset = None
255
+
256
+ if opt.lr_warmup > 0:
257
+ opt.lr_warmup = opt.n_epoch
258
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
259
+ logger.info(f"Model {model}")
260
+ count_parameters(model)
261
+ logger.info("Start Training...")
262
+ train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
263
+ return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
264
+
265
+
266
+ if __name__ == '__main__':
267
+ best_ckpt_path, eval_split_name, eval_path, debug = start_training()
268
+ if not debug:
269
+ input_args = ["--resume", best_ckpt_path,
270
+ "--eval_split_name", eval_split_name,
271
+ "--eval_path", eval_path]
272
+
273
+ import sys
274
+ sys.argv[1:] = input_args
275
+ logger.info("\n\n\nFINISHED TRAINING!!!")
276
+ logger.info("Evaluating model at {}".format(best_ckpt_path))
277
+ logger.info("Input args {}".format(sys.argv[1:]))
278
+ start_inference()
main/train_vlp_ddp.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pdb
3
+ import sys
4
+ import time
5
+ import json
6
+ import pprint
7
+ import random
8
+ import numpy as np
9
+ from tqdm import tqdm, trange
10
+ from collections import defaultdict
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.distributed as dist
15
+ import torch.backends.cudnn as cudnn
16
+ from torch.utils.data import DataLoader
17
+ from torch.utils.tensorboard import SummaryWriter
18
+ from torch.utils.data.distributed import DistributedSampler
19
+
20
+ sys.path.append('/data/home/qinghonglin/univtg')
21
+ from main.config import BaseOptions, setup_model
22
+ from main.dataset import \
23
+ DatasetMR, DatasetVLP, start_end_collate_mr, prepare_batch_inputs_mr
24
+ from main.inference_mr import eval_epoch, start_inference
25
+ from utils.basic_utils import set_seed, AverageMeter, dict_to_markdown
26
+ from utils.model_utils import count_parameters
27
+
28
+ import logging
29
+ logger = logging.getLogger(__name__)
30
+ logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
31
+ datefmt="%Y-%m-%d %H:%M:%S",
32
+ level=logging.INFO)
33
+
34
+ def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer):
35
+ logger.info(f"[Epoch {epoch_i+1}]")
36
+ model.train()
37
+ criterion.train()
38
+
39
+ # init meters
40
+ time_meters = defaultdict(AverageMeter)
41
+ loss_meters = defaultdict(AverageMeter)
42
+
43
+ num_training_examples = len(train_loader)
44
+ timer_dataloading = time.time()
45
+ for batch_idx, batch in tqdm(enumerate(train_loader),
46
+ desc="Training Iteration",
47
+ total=num_training_examples):
48
+ time_meters["dataloading_time"].update(time.time() - timer_dataloading)
49
+
50
+ timer_start = time.time()
51
+ model_inputs, targets = prepare_batch_inputs_mr(batch[1], torch.device("cuda", int(opt.local_rank)), non_blocking=opt.pin_memory)
52
+ time_meters["prepare_inputs_time"].update(time.time() - timer_start)
53
+
54
+ timer_start = time.time()
55
+
56
+ # try:
57
+ outputs = model(**model_inputs)
58
+ loss_dict = criterion(outputs, targets)
59
+ weight_dict = criterion.weight_dict
60
+ losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
61
+ time_meters["model_forward_time"].update(time.time() - timer_start)
62
+
63
+ timer_start = time.time()
64
+ optimizer.zero_grad()
65
+ losses.backward()
66
+
67
+ if opt.grad_clip > 0:
68
+ nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
69
+ optimizer.step()
70
+ time_meters["model_backward_time"].update(time.time() - timer_start)
71
+
72
+ loss_dict["loss_overall"] = float(losses) # for logging only
73
+ for k, v in loss_dict.items():
74
+ loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v))
75
+
76
+ timer_dataloading = time.time()
77
+
78
+ # print/add logs
79
+ if int(opt.local_rank) in [0, -1]:
80
+ tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1)
81
+ for k, v in loss_meters.items():
82
+ tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1)
83
+
84
+ to_write = opt.train_log_txt_formatter.format(
85
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
86
+ epoch=epoch_i+1,
87
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in loss_meters.items()]))
88
+ with open(opt.train_log_filepath, "a") as f:
89
+ f.write(to_write)
90
+
91
+ logger.info("Epoch time stats:")
92
+ for name, meter in time_meters.items():
93
+ d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]}
94
+ logger.info(f"{name} ==> {d}")
95
+
96
+
97
+ def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt):
98
+ if int(opt.local_rank) in [0, -1]:
99
+ tb_writer = SummaryWriter(opt.tensorboard_log_dir)
100
+ tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None))
101
+ opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
102
+ opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n"
103
+ else:
104
+ tb_writer = None
105
+
106
+ train_loader = DataLoader(
107
+ train_dataset,
108
+ collate_fn=start_end_collate_mr,
109
+ batch_size=opt.bsz,
110
+ num_workers=opt.num_workers,
111
+ # shuffle=True,
112
+ pin_memory=opt.pin_memory,
113
+ sampler=DistributedSampler(train_dataset)
114
+ )
115
+
116
+ prev_best_score = 0.
117
+ es_cnt = 0
118
+ if opt.start_epoch is None:
119
+ start_epoch = -1 if opt.eval_init else 0
120
+ else:
121
+ start_epoch = opt.start_epoch
122
+ save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name)
123
+ for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
124
+ if epoch_i > -1:
125
+ train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer)
126
+ lr_scheduler.step()
127
+ eval_epoch_interval = opt.eval_epoch
128
+ if int(opt.local_rank) in [0, -1] and opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0:
129
+ with torch.no_grad():
130
+ metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \
131
+ eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer)
132
+
133
+ # log
134
+ to_write = opt.eval_log_txt_formatter.format(
135
+ time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
136
+ epoch=epoch_i,
137
+ loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]),
138
+ eval_metrics_str=json.dumps(metrics_no_nms))
139
+
140
+ if int(opt.local_rank) in [0, -1]:
141
+ with open(opt.eval_log_filepath, "a") as f:
142
+ f.write(to_write)
143
+ logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4)))
144
+ if metrics_nms is not None:
145
+ logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4)))
146
+
147
+ metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
148
+ for k, v in metrics["brief"].items():
149
+ tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1)
150
+
151
+ # stop_score = metrics["brief"]["MR-full-mAP"]
152
+ # pdb.set_trace()
153
+ stop_score = metrics["brief"][opt.main_metric]
154
+ if stop_score > prev_best_score:
155
+ es_cnt = 0
156
+ prev_best_score = stop_score
157
+
158
+ checkpoint = {
159
+ "model": model.state_dict(),
160
+ "optimizer": optimizer.state_dict(),
161
+ "lr_scheduler": lr_scheduler.state_dict(),
162
+ "epoch": epoch_i,
163
+ "opt": opt
164
+ }
165
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"))
166
+
167
+ best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
168
+ for src, tgt in zip(latest_file_paths, best_file_paths):
169
+ os.renames(src, tgt)
170
+ logger.info("The checkpoint file has been updated.")
171
+ else:
172
+ es_cnt += 1
173
+ if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop
174
+ with open(opt.train_log_filepath, "a") as f:
175
+ f.write(f"Early Stop at epoch {epoch_i}")
176
+ logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n")
177
+ break
178
+
179
+ # save ckpt
180
+ checkpoint = {
181
+ "model": model.state_dict(),
182
+ "optimizer": optimizer.state_dict(),
183
+ "lr_scheduler": lr_scheduler.state_dict(),
184
+ "epoch": epoch_i,
185
+ "opt": opt
186
+ }
187
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt"))
188
+
189
+ if int(opt.local_rank) in [0, -1] and ((epoch_i + 1) % opt.save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0): # additional copies
190
+ checkpoint = {
191
+ "model": model.state_dict(),
192
+ "optimizer": optimizer.state_dict(),
193
+ "epoch": epoch_i,
194
+ "opt": opt
195
+ }
196
+ torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt"))
197
+
198
+ if opt.debug:
199
+ break
200
+
201
+ if int(opt.local_rank) in [0, -1]:
202
+ tb_writer.close()
203
+
204
+
205
+ def start_training():
206
+ # logger.info("Setup config, data and model...")
207
+ opt = BaseOptions().parse()
208
+ set_seed(opt.seed)
209
+ if opt.debug: # keep the model run deterministically
210
+ # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config.
211
+ # Enable this only when input size is fixed.
212
+ cudnn.benchmark = False
213
+ cudnn.deterministic = True
214
+
215
+ local_rank = int(opt.local_rank)
216
+ dist.init_process_group(backend='nccl')
217
+
218
+ torch.cuda.set_device(local_rank)
219
+ device = torch.device("cuda", local_rank)
220
+
221
+ dataset_config = dict(
222
+ dset_name=opt.dset_name,
223
+ data_path=opt.train_path,
224
+ v_feat_dirs=opt.v_feat_dirs,
225
+ q_feat_dir=opt.t_feat_dir,
226
+ v_feat_dim=opt.v_feat_dim,
227
+ q_feat_dim=opt.t_feat_dim,
228
+ q_feat_type="last_hidden_state",
229
+ max_q_l=opt.max_q_l,
230
+ max_v_l=opt.max_v_l,
231
+ ctx_mode=opt.ctx_mode,
232
+ data_ratio=opt.data_ratio,
233
+ normalize_v=not opt.no_norm_vfeat,
234
+ normalize_t=not opt.no_norm_tfeat,
235
+ clip_len=opt.clip_length,
236
+ max_windows=opt.max_windows,
237
+ span_loss_type=opt.span_loss_type,
238
+ txt_drop_ratio=opt.txt_drop_ratio,
239
+ use_cache=opt.use_cache,
240
+ add_easy_negative=opt.add_easy_negative,
241
+ easy_negative_only=opt.easy_negative_only
242
+ )
243
+
244
+ dataset_config["data_path"] = opt.train_path
245
+ train_dataset = DatasetVLP(**dataset_config)
246
+
247
+ if opt.eval_path is not None:
248
+ # perform zero-shot on qvhl.
249
+ dataset_config["data_path"] = opt.eval_path
250
+ dataset_config["txt_drop_ratio"] = 0
251
+ if len(dataset_config["v_feat_dirs"]) == 1:
252
+ dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_clip"]
253
+ elif len(dataset_config["v_feat_dirs"]) == 2:
254
+ dataset_config["v_feat_dirs"] = ["data/qvhighlights/vid_slowfast", "data/qvhighlights/vid_clip"]
255
+ else:
256
+ raise NotImplementedError
257
+ dataset_config["q_feat_dir"] = "data/qvhighlights/txt_clip"
258
+ dataset_config["data_ratio"] = 1
259
+ # dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("txt_clip_asr", "txt_clip").replace("txt_clip_cap", "txt_clip") # for pretraining
260
+ eval_dataset = DatasetMR(**dataset_config)
261
+ else:
262
+ eval_dataset = None
263
+
264
+ if opt.lr_warmup > 0:
265
+ # total_steps = opt.n_epoch * len(train_dataset) // opt.bsz
266
+ total_steps = opt.n_epoch
267
+ warmup_steps = opt.lr_warmup if opt.lr_warmup > 1 else int(opt.lr_warmup * total_steps)
268
+ opt.lr_warmup = [warmup_steps, total_steps]
269
+ model, criterion, optimizer, lr_scheduler = setup_model(opt)
270
+
271
+ model.to(device)
272
+ logger.info(f"Using {torch.cuda.device_count()} GPUs.")
273
+ model = torch.nn.parallel.DistributedDataParallel(model,
274
+ device_ids=[local_rank],
275
+ output_device=local_rank,
276
+ find_unused_parameters=True)
277
+
278
+ if int(opt.local_rank) in [0, -1]:
279
+ logger.info(f"Model {model}")
280
+ count_parameters(model)
281
+ logger.info("Start Training...")
282
+ train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt)
283
+ # return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug
284
+ return
285
+
286
+ if __name__ == '__main__':
287
+ # best_ckpt_path, eval_split_name, eval_path, debug = start_training()
288
+ start_training()
model/base.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ # MLP Projector
103
+ self.weightedpool = WeightedPool(hidden_dim)
104
+
105
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
106
+ bs = src_vid.shape[0]
107
+ src_vid = self.input_vid_proj(src_vid)
108
+ src_txt = self.input_txt_proj(src_txt)
109
+ if src_cls is not None:
110
+ src_cls = self.input_txt_proj(src_cls)
111
+
112
+ # type token.
113
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
114
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
115
+ if src_cls is not None:
116
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
117
+
118
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
119
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
120
+
121
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
122
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
123
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
124
+
125
+ memory = self.transformer(src, ~mask, pos)
126
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
127
+
128
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
129
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
130
+
131
+ if self.span_loss_type == "l1":
132
+ outputs_coord = outputs_coord.sigmoid()
133
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
134
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
135
+ outputs_coord = outputs_coord * idx_mask
136
+ else:
137
+ raise NotImplementedError
138
+
139
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
140
+ 'src_vid_mask': src_vid_mask}
141
+
142
+ vid_mem_proj = src_vid
143
+
144
+ # word-level -> sentence-level
145
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
146
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
147
+
148
+ out["vid_mem_proj"] = vid_mem_proj
149
+ out["txt_mem_proj"] = txt_mem_proj
150
+ if src_cls is not None:
151
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
152
+ out["cls_mem_proj"] = cls_mem_proj
153
+ out["saliency_scores"] = sim
154
+ return out
155
+
156
+ class SetCriterion(nn.Module):
157
+ """ This class computes the loss for DETR.
158
+ The process happens in two steps:
159
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
160
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
161
+ """
162
+
163
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
164
+ saliency_margin=1):
165
+ """ Create the criterion.
166
+ Parameters:
167
+ matcher: module able to compute a matching between targets and proposals
168
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
169
+ eos_coef: relative classification weight applied to the no-object category
170
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
171
+ temperature: float, temperature for NCE loss
172
+ span_loss_type: str, [l1, ce]
173
+ max_v_l: int,
174
+ saliency_margin: float
175
+ """
176
+ super().__init__()
177
+ self.matcher = matcher
178
+ self.weight_dict = weight_dict
179
+ self.losses = losses
180
+ self.temperature = temperature
181
+ self.span_loss_type = span_loss_type
182
+ self.max_v_l = max_v_l
183
+ self.saliency_margin = saliency_margin
184
+ self.temperature = 0.07
185
+
186
+ # foreground and background classification
187
+ self.foreground_label = 0
188
+ self.background_label = 1
189
+ self.eos_coef = eos_coef
190
+ empty_weight = torch.ones(2)
191
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
192
+ self.register_buffer('empty_weight', empty_weight)
193
+
194
+ def loss_spans(self, outputs, targets, indices):
195
+ assert 'pred_spans' in outputs
196
+
197
+ start_spans = targets['timestamp']
198
+ pred_spans = outputs['pred_spans']
199
+ src_spans = start_spans + pred_spans
200
+ gt_spans = targets['span_labels_nn']
201
+
202
+ mask = targets['timestamp_mask'].bool()
203
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
204
+ mask_valid = targets['timestamp_window'].bool()
205
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
206
+
207
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
208
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
209
+
210
+ losses = {}
211
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
212
+ losses['loss_g'] = loss_giou.mean()
213
+ return losses
214
+
215
+ def loss_labels(self, outputs, targets, indices, log=True):
216
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
217
+ mask = targets['timestamp_mask'].bool()
218
+ mask_valid = targets['timestamp_window'].bool()
219
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
220
+ target_classes[mask_valid] = 1
221
+ # target_classes = targets['timestamp_window'] # soft cls.
222
+ target_classes.float()
223
+ # pdb.set_trace()
224
+
225
+ weights = torch.zeros_like(target_classes).float()
226
+ weights[mask] = self.empty_weight[1]
227
+ weights[mask_valid] = self.empty_weight[0]
228
+
229
+ # pdb.set_trace()
230
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
231
+ return {"loss_f": loss_ce.sum() / mask.sum()}
232
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
233
+
234
+ def loss_saliency(self, outputs, targets, indices, log=True):
235
+ """higher scores for positive clips"""
236
+ if "saliency_pos_labels" not in targets:
237
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
238
+ saliency_scores = targets["saliency_scores"]
239
+ if saliency_scores.sum() == 0:
240
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
241
+
242
+ # * inter-vid mode
243
+ vid_mem_proj = outputs["vid_mem_proj"]
244
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
245
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
246
+
247
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
248
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
249
+ sim = sim_matrix(vid_feats, txt_feats)
250
+
251
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
252
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
253
+
254
+ # sum over positives
255
+ idiag = torch.diag(i_logsm)
256
+ jdiag = torch.diag(j_logsm)
257
+ loss_i = idiag.sum() / len(idiag)
258
+ loss_j = jdiag.sum() / len(jdiag)
259
+
260
+ loss_saliency_inter = - loss_i - loss_j
261
+
262
+ # * intra-vid mode
263
+ mask = targets['timestamp_mask']
264
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
265
+ neg_indices_in = (saliency_scores < selected_scores)
266
+ neg_indices_in[batch_indices, pos_indices] = True
267
+ mask_invalid = neg_indices_in * mask.bool()
268
+
269
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
270
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
271
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
272
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
273
+
274
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
275
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
276
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
277
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
278
+
279
+ loss_saliency_intra = - loss_in_i - loss_in_j
280
+
281
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
282
+
283
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
284
+ """higher scores for positive clips"""
285
+ if "saliency_pos_labels" not in targets:
286
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
287
+ saliency_scores = targets["saliency_scores"]
288
+ if saliency_scores.sum() == 0:
289
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
290
+
291
+ # * inter-vid mode
292
+ vid_mem_proj = outputs["vid_mem_proj"]
293
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
294
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
295
+
296
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
297
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
298
+ sim = sim_matrix(vid_feats, txt_feats)
299
+
300
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
301
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
302
+
303
+ # sum over positives
304
+ idiag = torch.diag(i_logsm)
305
+ jdiag = torch.diag(j_logsm)
306
+ loss_i = idiag.sum() / len(idiag)
307
+ loss_j = jdiag.sum() / len(jdiag)
308
+
309
+ loss_saliency_inter = - loss_i - loss_j
310
+
311
+ # * intra-vid mode
312
+ if 'cls_idx' not in targets.keys(): # eval
313
+ return {"loss_s_inter": loss_saliency_inter}
314
+
315
+ cls_indices = targets['cls_idx'].bool()
316
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
317
+ sim_cls = sim_matrix(vid_feats, cls_feats)
318
+
319
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
320
+ idiag_cls = i_logsm_cls[cls_indices]
321
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
322
+
323
+ loss_saliency_intra = - loss_cls_i
324
+
325
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
326
+
327
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
328
+ loss_map = {
329
+ "spans": self.loss_spans,
330
+ "labels": self.loss_labels,
331
+ "saliency": self.loss_saliency,
332
+ "saliency_cls": self.loss_saliency_cls,
333
+ }
334
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
335
+ return loss_map[loss](outputs, targets, indices, **kwargs)
336
+
337
+ def forward(self, outputs, targets, hl_only=False):
338
+ """ This performs the loss computation.
339
+ Parameters:
340
+ outputs: dict of tensors, see the output specification of the model for the format
341
+ targets: list of dicts, such that len(targets) == batch_size.
342
+ The expected keys in each dict depends on the losses applied, see each loss' doc
343
+ """
344
+ indices = None
345
+ # Compute all the requested losses
346
+ losses = {}
347
+ for loss in self.losses:
348
+ losses.update(self.get_loss(loss, outputs, targets, indices))
349
+
350
+ return losses
351
+
352
+ class MLP(nn.Module):
353
+ """ Very simple multi-layer perceptron (also called FFN)"""
354
+
355
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
356
+ super().__init__()
357
+ self.num_layers = num_layers
358
+ h = [hidden_dim] * (num_layers - 1)
359
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
360
+
361
+ def forward(self, x):
362
+ for i, layer in enumerate(self.layers):
363
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
364
+ return x
365
+
366
+ class Conv(nn.Module):
367
+ """ Very simple multi-layer perceptron (also called FFN)"""
368
+
369
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
370
+ super().__init__()
371
+ self.num_layers = num_layers
372
+ h = [hidden_dim] * (num_layers - 1)
373
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
374
+ self.layers = nn.ModuleList(
375
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
376
+ for n, k in zip([input_dim] + h, h + [output_dim]))
377
+ def forward(self, x):
378
+ x = x.permute(0,2,1)
379
+ for i, layer in enumerate(self.layers):
380
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
381
+ return x.permute(0, 2, 1)
382
+
383
+ class LinearLayer(nn.Module):
384
+ """linear layer configurable with layer normalization, dropout, ReLU."""
385
+
386
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
387
+ super(LinearLayer, self).__init__()
388
+ self.relu = relu
389
+ self.layer_norm = layer_norm
390
+ if layer_norm:
391
+ self.LayerNorm = nn.LayerNorm(in_hsz)
392
+ layers = [
393
+ nn.Dropout(dropout),
394
+ nn.Linear(in_hsz, out_hsz)
395
+ ]
396
+ self.net = nn.Sequential(*layers)
397
+
398
+ def forward(self, x):
399
+ """(N, L, D)"""
400
+ if self.layer_norm:
401
+ x = self.LayerNorm(x)
402
+ x = self.net(x)
403
+ if self.relu:
404
+ x = F.relu(x, inplace=True)
405
+ return x # (N, L, D)
406
+
407
+
408
+ def build_model(args):
409
+ device = torch.device(args.device)
410
+
411
+ transformer = build_transformer(args)
412
+ position_embedding, txt_position_embedding = build_position_encoding(args)
413
+
414
+ model = Model(
415
+ transformer,
416
+ position_embedding,
417
+ txt_position_embedding,
418
+ txt_dim=args.t_feat_dim,
419
+ vid_dim=args.v_feat_dim,
420
+ input_dropout=args.input_dropout,
421
+ span_loss_type=args.span_loss_type,
422
+ use_txt_pos=args.use_txt_pos,
423
+ n_input_proj=args.n_input_proj,
424
+ )
425
+
426
+ matcher = build_matcher(args)
427
+ weight_dict = {"loss_b": args.b_loss_coef,
428
+ "loss_g": args.g_loss_coef,
429
+ "loss_f": args.f_loss_coef,
430
+ "loss_s_intra": args.s_loss_intra_coef,
431
+ "loss_s_inter": args.s_loss_inter_coef}
432
+
433
+ if args.dset_type in ['mr', 'vlp']:
434
+ if 'tal' not in args.train_path:
435
+ losses = ['spans', 'labels', 'saliency']
436
+ else:
437
+ losses = ['spans', 'labels', 'saliency_cls']
438
+ elif args.dset_type in ['hl', 'vs']:
439
+ losses = ['labels', 'saliency']
440
+
441
+ criterion = SetCriterion(
442
+ matcher=matcher,
443
+ weight_dict=weight_dict, losses=losses,
444
+ eos_coef=args.eos_coef, temperature=args.temperature,
445
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
446
+ saliency_margin=args.saliency_margin,
447
+ )
448
+ criterion.to(device)
449
+ return model, criterion
model/base_albef.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder import build_transformer, Transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer_mm, transformer_v, transformer_t, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer_mm
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer_mm.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ self.transformer_v = transformer_v
103
+ self.transformer_t = transformer_t
104
+
105
+ # MLP Projector
106
+ self.weightedpool = WeightedPool(hidden_dim)
107
+
108
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
109
+ bs = src_vid.shape[0]
110
+ src_vid = self.input_vid_proj(src_vid)
111
+ src_txt = self.input_txt_proj(src_txt)
112
+ if src_cls is not None:
113
+ src_cls = self.input_txt_proj(src_cls)
114
+
115
+ # pos embed.
116
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
117
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
118
+
119
+ src_vid = self.transformer_v(src_vid, ~src_vid_mask.bool(), pos_vid)
120
+ src_txt = self.transformer_t(src_txt, ~src_txt_mask.bool(), pos_txt)
121
+
122
+ # type token.
123
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
124
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
125
+ if src_cls is not None:
126
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
127
+
128
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
129
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
130
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
131
+
132
+ memory = self.transformer(src, ~mask, pos)
133
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
134
+
135
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
136
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
137
+
138
+ if self.span_loss_type == "l1":
139
+ outputs_coord = outputs_coord.sigmoid()
140
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
141
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
142
+ outputs_coord = outputs_coord * idx_mask
143
+ else:
144
+ raise NotImplementedError
145
+
146
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
147
+ 'src_vid_mask': src_vid_mask}
148
+
149
+ vid_mem_proj = src_vid
150
+
151
+ # word-level -> sentence-level
152
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
153
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
154
+
155
+ out["vid_mem_proj"] = vid_mem_proj
156
+ out["txt_mem_proj"] = txt_mem_proj
157
+ if src_cls is not None:
158
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
159
+ out["cls_mem_proj"] = cls_mem_proj
160
+ out["saliency_scores"] = sim
161
+ return out
162
+
163
+ class SetCriterion(nn.Module):
164
+ """ This class computes the loss for DETR.
165
+ The process happens in two steps:
166
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
167
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
168
+ """
169
+
170
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
171
+ saliency_margin=1):
172
+ """ Create the criterion.
173
+ Parameters:
174
+ matcher: module able to compute a matching between targets and proposals
175
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
176
+ eos_coef: relative classification weight applied to the no-object category
177
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
178
+ temperature: float, temperature for NCE loss
179
+ span_loss_type: str, [l1, ce]
180
+ max_v_l: int,
181
+ saliency_margin: float
182
+ """
183
+ super().__init__()
184
+ self.matcher = matcher
185
+ self.weight_dict = weight_dict
186
+ self.losses = losses
187
+ self.temperature = temperature
188
+ self.span_loss_type = span_loss_type
189
+ self.max_v_l = max_v_l
190
+ self.saliency_margin = saliency_margin
191
+ self.temperature = 0.07
192
+
193
+ # foreground and background classification
194
+ self.foreground_label = 0
195
+ self.background_label = 1
196
+ self.eos_coef = eos_coef
197
+ empty_weight = torch.ones(2)
198
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
199
+ self.register_buffer('empty_weight', empty_weight)
200
+
201
+ def loss_spans(self, outputs, targets, indices):
202
+ assert 'pred_spans' in outputs
203
+
204
+ start_spans = targets['timestamp']
205
+ pred_spans = outputs['pred_spans']
206
+ src_spans = start_spans + pred_spans
207
+ gt_spans = targets['span_labels_nn']
208
+
209
+ mask = targets['timestamp_mask'].bool()
210
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
211
+ mask_valid = targets['timestamp_window'].bool()
212
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
213
+
214
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
215
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
216
+
217
+ losses = {}
218
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
219
+ losses['loss_g'] = loss_giou.mean()
220
+ return losses
221
+
222
+ def loss_labels(self, outputs, targets, indices, log=True):
223
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
224
+ mask = targets['timestamp_mask'].bool()
225
+ mask_valid = targets['timestamp_window'].bool()
226
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
227
+ target_classes[mask_valid] = 1
228
+ # target_classes = targets['timestamp_window'] # soft cls.
229
+ target_classes.float()
230
+ # pdb.set_trace()
231
+
232
+ weights = torch.zeros_like(target_classes).float()
233
+ weights[mask] = self.empty_weight[1]
234
+ weights[mask_valid] = self.empty_weight[0]
235
+
236
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
237
+ return {"loss_f": loss_ce.sum() / mask.sum()}
238
+
239
+ def loss_saliency(self, outputs, targets, indices, log=True):
240
+ """higher scores for positive clips"""
241
+ if "saliency_pos_labels" not in targets:
242
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
243
+ saliency_scores = targets["saliency_scores"]
244
+ if saliency_scores.sum() == 0:
245
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
246
+
247
+ # * inter-vid mode
248
+ vid_mem_proj = outputs["vid_mem_proj"]
249
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
250
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
251
+
252
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
253
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
254
+ sim = sim_matrix(vid_feats, txt_feats)
255
+
256
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
257
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
258
+
259
+ # sum over positives
260
+ idiag = torch.diag(i_logsm)
261
+ jdiag = torch.diag(j_logsm)
262
+ loss_i = idiag.sum() / len(idiag)
263
+ loss_j = jdiag.sum() / len(jdiag)
264
+
265
+ loss_saliency_inter = - loss_i - loss_j
266
+
267
+ # * intra-vid mode
268
+ mask = targets['timestamp_mask']
269
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
270
+ neg_indices_in = (saliency_scores < selected_scores)
271
+ neg_indices_in[batch_indices, pos_indices] = True
272
+ mask_invalid = neg_indices_in * mask.bool()
273
+
274
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
275
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
276
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
277
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
278
+
279
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
280
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
281
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
282
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
283
+
284
+ loss_saliency_intra = - loss_in_i - loss_in_j
285
+
286
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
287
+
288
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
289
+ """higher scores for positive clips"""
290
+ if "saliency_pos_labels" not in targets:
291
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
292
+ saliency_scores = targets["saliency_scores"]
293
+ if saliency_scores.sum() == 0:
294
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
295
+
296
+ # * inter-vid mode
297
+ vid_mem_proj = outputs["vid_mem_proj"]
298
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
299
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
300
+
301
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
302
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
303
+ sim = sim_matrix(vid_feats, txt_feats)
304
+
305
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
306
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
307
+
308
+ # sum over positives
309
+ idiag = torch.diag(i_logsm)
310
+ jdiag = torch.diag(j_logsm)
311
+ loss_i = idiag.sum() / len(idiag)
312
+ loss_j = jdiag.sum() / len(jdiag)
313
+
314
+ loss_saliency_inter = - loss_i - loss_j
315
+
316
+ # * intra-vid mode
317
+ if 'cls_idx' not in targets.keys(): # eval
318
+ return {"loss_s_inter": loss_saliency_inter}
319
+
320
+ cls_indices = targets['cls_idx'].bool()
321
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
322
+ sim_cls = sim_matrix(vid_feats, cls_feats)
323
+
324
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
325
+ idiag_cls = i_logsm_cls[cls_indices]
326
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
327
+
328
+ loss_saliency_intra = - loss_cls_i
329
+
330
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
331
+
332
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
333
+ loss_map = {
334
+ "spans": self.loss_spans,
335
+ "labels": self.loss_labels,
336
+ "saliency": self.loss_saliency,
337
+ "saliency_cls": self.loss_saliency_cls,
338
+ }
339
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
340
+ return loss_map[loss](outputs, targets, indices, **kwargs)
341
+
342
+ def forward(self, outputs, targets, hl_only=False):
343
+ """ This performs the loss computation.
344
+ Parameters:
345
+ outputs: dict of tensors, see the output specification of the model for the format
346
+ targets: list of dicts, such that len(targets) == batch_size.
347
+ The expected keys in each dict depends on the losses applied, see each loss' doc
348
+ """
349
+ indices = None
350
+ # Compute all the requested losses
351
+ losses = {}
352
+ for loss in self.losses:
353
+ losses.update(self.get_loss(loss, outputs, targets, indices))
354
+
355
+ return losses
356
+
357
+ class MLP(nn.Module):
358
+ """ Very simple multi-layer perceptron (also called FFN)"""
359
+
360
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
361
+ super().__init__()
362
+ self.num_layers = num_layers
363
+ h = [hidden_dim] * (num_layers - 1)
364
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
365
+
366
+ def forward(self, x):
367
+ for i, layer in enumerate(self.layers):
368
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
369
+ return x
370
+
371
+ class Conv(nn.Module):
372
+ """ Very simple multi-layer perceptron (also called FFN)"""
373
+
374
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
375
+ super().__init__()
376
+ self.num_layers = num_layers
377
+ h = [hidden_dim] * (num_layers - 1)
378
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
379
+ self.layers = nn.ModuleList(
380
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
381
+ for n, k in zip([input_dim] + h, h + [output_dim]))
382
+ def forward(self, x):
383
+ x = x.permute(0,2,1)
384
+ for i, layer in enumerate(self.layers):
385
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
386
+ return x.permute(0, 2, 1)
387
+
388
+ class LinearLayer(nn.Module):
389
+ """linear layer configurable with layer normalization, dropout, ReLU."""
390
+
391
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
392
+ super(LinearLayer, self).__init__()
393
+ self.relu = relu
394
+ self.layer_norm = layer_norm
395
+ if layer_norm:
396
+ self.LayerNorm = nn.LayerNorm(in_hsz)
397
+ layers = [
398
+ nn.Dropout(dropout),
399
+ nn.Linear(in_hsz, out_hsz)
400
+ ]
401
+ self.net = nn.Sequential(*layers)
402
+
403
+ def forward(self, x):
404
+ """(N, L, D)"""
405
+ if self.layer_norm:
406
+ x = self.LayerNorm(x)
407
+ x = self.net(x)
408
+ if self.relu:
409
+ x = F.relu(x, inplace=True)
410
+ return x # (N, L, D)
411
+
412
+
413
+ def build_model(args):
414
+ device = torch.device(args.device)
415
+
416
+ transformer_mm = build_transformer(args)
417
+ transformer_v = Transformer(
418
+ d_model=args.hidden_dim,
419
+ dropout=args.dropout,
420
+ nhead=args.nheads,
421
+ dim_feedforward=args.dim_feedforward,
422
+ num_encoder_layers=args.sub_enc_layers,
423
+ num_decoder_layers=args.dec_layers,
424
+ normalize_before=args.pre_norm,
425
+ return_intermediate_dec=True,
426
+ )
427
+ transformer_t = Transformer(
428
+ d_model=args.hidden_dim,
429
+ dropout=args.dropout,
430
+ nhead=args.nheads,
431
+ dim_feedforward=args.dim_feedforward,
432
+ num_encoder_layers=args.sub_enc_layers,
433
+ num_decoder_layers=args.dec_layers,
434
+ normalize_before=args.pre_norm,
435
+ return_intermediate_dec=True,
436
+ )
437
+ # pdb.set_trace()
438
+
439
+ position_embedding, txt_position_embedding = build_position_encoding(args)
440
+
441
+ model = Model(
442
+ transformer_mm,
443
+ transformer_v,
444
+ transformer_t,
445
+ position_embedding,
446
+ txt_position_embedding,
447
+ txt_dim=args.t_feat_dim,
448
+ vid_dim=args.v_feat_dim,
449
+ input_dropout=args.input_dropout,
450
+ span_loss_type=args.span_loss_type,
451
+ use_txt_pos=args.use_txt_pos,
452
+ n_input_proj=args.n_input_proj,
453
+ )
454
+
455
+ matcher = build_matcher(args)
456
+ weight_dict = {"loss_b": args.b_loss_coef,
457
+ "loss_g": args.g_loss_coef,
458
+ "loss_f": args.f_loss_coef,
459
+ "loss_s_intra": args.s_loss_intra_coef,
460
+ "loss_s_inter": args.s_loss_inter_coef}
461
+
462
+ if args.dset_type in ['mr', 'vlp']:
463
+ if 'tal' not in args.train_path:
464
+ losses = ['spans', 'labels', 'saliency']
465
+ else:
466
+ losses = ['spans', 'labels', 'saliency_cls']
467
+ elif args.dset_type in ['hl', 'vs']:
468
+ losses = ['labels', 'saliency']
469
+
470
+ criterion = SetCriterion(
471
+ matcher=matcher,
472
+ weight_dict=weight_dict, losses=losses,
473
+ eos_coef=args.eos_coef, temperature=args.temperature,
474
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
475
+ saliency_margin=args.saliency_margin,
476
+ )
477
+ criterion.to(device)
478
+ return model, criterion
model/base_droppath.py ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder_droppath import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ # MLP Projector
103
+ self.weightedpool = WeightedPool(hidden_dim)
104
+
105
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
106
+ bs = src_vid.shape[0]
107
+ src_vid = self.input_vid_proj(src_vid)
108
+ src_txt = self.input_txt_proj(src_txt)
109
+ if src_cls is not None:
110
+ src_cls = self.input_txt_proj(src_cls)
111
+
112
+ # type token.
113
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
114
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
115
+ if src_cls is not None:
116
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
117
+
118
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
119
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
120
+
121
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
122
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
123
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
124
+
125
+ memory = self.transformer(src, ~mask, pos)
126
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
127
+
128
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
129
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
130
+
131
+ if self.span_loss_type == "l1":
132
+ outputs_coord = outputs_coord.sigmoid()
133
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
134
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
135
+ outputs_coord = outputs_coord * idx_mask
136
+ else:
137
+ raise NotImplementedError
138
+
139
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
140
+ 'src_vid_mask': src_vid_mask}
141
+
142
+ vid_mem_proj = src_vid
143
+
144
+ # word-level -> sentence-level
145
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
146
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
147
+
148
+ out["vid_mem_proj"] = vid_mem_proj
149
+ out["txt_mem_proj"] = txt_mem_proj
150
+ if src_cls is not None:
151
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
152
+ out["cls_mem_proj"] = cls_mem_proj
153
+ out["saliency_scores"] = sim
154
+ return out
155
+
156
+ class SetCriterion(nn.Module):
157
+ """ This class computes the loss for DETR.
158
+ The process happens in two steps:
159
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
160
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
161
+ """
162
+
163
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
164
+ saliency_margin=1):
165
+ """ Create the criterion.
166
+ Parameters:
167
+ matcher: module able to compute a matching between targets and proposals
168
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
169
+ eos_coef: relative classification weight applied to the no-object category
170
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
171
+ temperature: float, temperature for NCE loss
172
+ span_loss_type: str, [l1, ce]
173
+ max_v_l: int,
174
+ saliency_margin: float
175
+ """
176
+ super().__init__()
177
+ self.matcher = matcher
178
+ self.weight_dict = weight_dict
179
+ self.losses = losses
180
+ self.temperature = temperature
181
+ self.span_loss_type = span_loss_type
182
+ self.max_v_l = max_v_l
183
+ self.saliency_margin = saliency_margin
184
+ self.temperature = 0.07
185
+
186
+ # foreground and background classification
187
+ self.foreground_label = 0
188
+ self.background_label = 1
189
+ self.eos_coef = eos_coef
190
+ empty_weight = torch.ones(2)
191
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
192
+ self.register_buffer('empty_weight', empty_weight)
193
+
194
+ def loss_spans(self, outputs, targets, indices):
195
+ assert 'pred_spans' in outputs
196
+
197
+ start_spans = targets['timestamp']
198
+ pred_spans = outputs['pred_spans']
199
+ src_spans = start_spans + pred_spans
200
+ gt_spans = targets['span_labels_nn']
201
+
202
+ mask = targets['timestamp_mask'].bool()
203
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
204
+ mask_valid = targets['timestamp_window'].bool()
205
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
206
+
207
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
208
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
209
+
210
+ losses = {}
211
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
212
+ losses['loss_g'] = loss_giou.mean()
213
+ return losses
214
+
215
+ def loss_labels(self, outputs, targets, indices, log=True):
216
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
217
+ mask = targets['timestamp_mask'].bool()
218
+ mask_valid = targets['timestamp_window'].bool()
219
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
220
+ target_classes[mask_valid] = 1
221
+ # target_classes = targets['timestamp_window'] # soft cls.
222
+ target_classes.float()
223
+ # pdb.set_trace()
224
+
225
+ weights = torch.zeros_like(target_classes).float()
226
+ weights[mask] = self.empty_weight[1]
227
+ weights[mask_valid] = self.empty_weight[0]
228
+
229
+ # pdb.set_trace()
230
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
231
+ return {"loss_f": loss_ce.sum() / mask.sum()}
232
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
233
+
234
+ def loss_saliency(self, outputs, targets, indices, log=True):
235
+ """higher scores for positive clips"""
236
+ if "saliency_pos_labels" not in targets:
237
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
238
+ saliency_scores = targets["saliency_scores"]
239
+ if saliency_scores.sum() == 0:
240
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
241
+
242
+ # * inter-vid mode
243
+ vid_mem_proj = outputs["vid_mem_proj"]
244
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
245
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
246
+
247
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
248
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
249
+ sim = sim_matrix(vid_feats, txt_feats)
250
+
251
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
252
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
253
+
254
+ # sum over positives
255
+ idiag = torch.diag(i_logsm)
256
+ jdiag = torch.diag(j_logsm)
257
+ loss_i = idiag.sum() / len(idiag)
258
+ loss_j = jdiag.sum() / len(jdiag)
259
+
260
+ loss_saliency_inter = - loss_i - loss_j
261
+
262
+ # * intra-vid mode
263
+ mask = targets['timestamp_mask']
264
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
265
+ neg_indices_in = (saliency_scores < selected_scores)
266
+ neg_indices_in[batch_indices, pos_indices] = True
267
+ mask_invalid = neg_indices_in * mask.bool()
268
+
269
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
270
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
271
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
272
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
273
+
274
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
275
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
276
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
277
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
278
+
279
+ loss_saliency_intra = - loss_in_i - loss_in_j
280
+
281
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
282
+
283
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
284
+ """higher scores for positive clips"""
285
+ if "saliency_pos_labels" not in targets:
286
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
287
+ saliency_scores = targets["saliency_scores"]
288
+ if saliency_scores.sum() == 0:
289
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
290
+
291
+ # * inter-vid mode
292
+ vid_mem_proj = outputs["vid_mem_proj"]
293
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
294
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
295
+
296
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
297
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
298
+ sim = sim_matrix(vid_feats, txt_feats)
299
+
300
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
301
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
302
+
303
+ # sum over positives
304
+ idiag = torch.diag(i_logsm)
305
+ jdiag = torch.diag(j_logsm)
306
+ loss_i = idiag.sum() / len(idiag)
307
+ loss_j = jdiag.sum() / len(jdiag)
308
+
309
+ loss_saliency_inter = - loss_i - loss_j
310
+
311
+ # * intra-vid mode
312
+ if 'cls_idx' not in targets.keys(): # eval
313
+ return {"loss_s_inter": loss_saliency_inter}
314
+
315
+ cls_indices = targets['cls_idx'].bool()
316
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
317
+ sim_cls = sim_matrix(vid_feats, cls_feats)
318
+
319
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
320
+ idiag_cls = i_logsm_cls[cls_indices]
321
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
322
+
323
+ loss_saliency_intra = - loss_cls_i
324
+
325
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
326
+
327
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
328
+ loss_map = {
329
+ "spans": self.loss_spans,
330
+ "labels": self.loss_labels,
331
+ "saliency": self.loss_saliency,
332
+ "saliency_cls": self.loss_saliency_cls,
333
+ }
334
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
335
+ return loss_map[loss](outputs, targets, indices, **kwargs)
336
+
337
+ def forward(self, outputs, targets, hl_only=False):
338
+ """ This performs the loss computation.
339
+ Parameters:
340
+ outputs: dict of tensors, see the output specification of the model for the format
341
+ targets: list of dicts, such that len(targets) == batch_size.
342
+ The expected keys in each dict depends on the losses applied, see each loss' doc
343
+ """
344
+ indices = None
345
+ # Compute all the requested losses
346
+ losses = {}
347
+ for loss in self.losses:
348
+ losses.update(self.get_loss(loss, outputs, targets, indices))
349
+
350
+ return losses
351
+
352
+ class MLP(nn.Module):
353
+ """ Very simple multi-layer perceptron (also called FFN)"""
354
+
355
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
356
+ super().__init__()
357
+ self.num_layers = num_layers
358
+ h = [hidden_dim] * (num_layers - 1)
359
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
360
+
361
+ def forward(self, x):
362
+ for i, layer in enumerate(self.layers):
363
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
364
+ return x
365
+
366
+ class Conv(nn.Module):
367
+ """ Very simple multi-layer perceptron (also called FFN)"""
368
+
369
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
370
+ super().__init__()
371
+ self.num_layers = num_layers
372
+ h = [hidden_dim] * (num_layers - 1)
373
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
374
+ self.layers = nn.ModuleList(
375
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
376
+ for n, k in zip([input_dim] + h, h + [output_dim]))
377
+ def forward(self, x):
378
+ x = x.permute(0,2,1)
379
+ for i, layer in enumerate(self.layers):
380
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
381
+ return x.permute(0, 2, 1)
382
+
383
+ class LinearLayer(nn.Module):
384
+ """linear layer configurable with layer normalization, dropout, ReLU."""
385
+
386
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
387
+ super(LinearLayer, self).__init__()
388
+ self.relu = relu
389
+ self.layer_norm = layer_norm
390
+ if layer_norm:
391
+ self.LayerNorm = nn.LayerNorm(in_hsz)
392
+ layers = [
393
+ nn.Dropout(dropout),
394
+ nn.Linear(in_hsz, out_hsz)
395
+ ]
396
+ self.net = nn.Sequential(*layers)
397
+
398
+ def forward(self, x):
399
+ """(N, L, D)"""
400
+ if self.layer_norm:
401
+ x = self.LayerNorm(x)
402
+ x = self.net(x)
403
+ if self.relu:
404
+ x = F.relu(x, inplace=True)
405
+ return x # (N, L, D)
406
+
407
+
408
+ def build_model(args):
409
+ device = torch.device(args.device)
410
+
411
+ transformer = build_transformer(args)
412
+ position_embedding, txt_position_embedding = build_position_encoding(args)
413
+
414
+ model = Model(
415
+ transformer,
416
+ position_embedding,
417
+ txt_position_embedding,
418
+ txt_dim=args.t_feat_dim,
419
+ vid_dim=args.v_feat_dim,
420
+ input_dropout=args.input_dropout,
421
+ span_loss_type=args.span_loss_type,
422
+ use_txt_pos=args.use_txt_pos,
423
+ n_input_proj=args.n_input_proj,
424
+ )
425
+
426
+ matcher = build_matcher(args)
427
+ weight_dict = {"loss_b": args.b_loss_coef,
428
+ "loss_g": args.g_loss_coef,
429
+ "loss_f": args.f_loss_coef,
430
+ "loss_s_intra": args.s_loss_intra_coef,
431
+ "loss_s_inter": args.s_loss_inter_coef}
432
+
433
+ if args.dset_type in ['mr', 'vlp']:
434
+ if 'tal' not in args.train_path:
435
+ losses = ['spans', 'labels', 'saliency']
436
+ else:
437
+ losses = ['spans', 'labels', 'saliency_cls']
438
+ elif args.dset_type in ['hl', 'vs']:
439
+ losses = ['labels', 'saliency']
440
+
441
+ criterion = SetCriterion(
442
+ matcher=matcher,
443
+ weight_dict=weight_dict, losses=losses,
444
+ eos_coef=args.eos_coef, temperature=args.temperature,
445
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
446
+ saliency_margin=args.saliency_margin,
447
+ )
448
+ criterion.to(device)
449
+ return model, criterion
model/base_droppath_ablation.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder_droppath import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ # MLP Projector
103
+ self.weightedpool = WeightedPool(hidden_dim)
104
+
105
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
106
+ bs = src_vid.shape[0]
107
+ src_vid = self.input_vid_proj(src_vid)
108
+ src_txt = self.input_txt_proj(src_txt)
109
+ if src_cls is not None:
110
+ src_cls = self.input_txt_proj(src_cls)
111
+
112
+ # type token.
113
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
114
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
115
+ if src_cls is not None:
116
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
117
+
118
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
119
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
120
+
121
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
122
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
123
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
124
+
125
+ memory = self.transformer(src, ~mask, pos)
126
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
127
+
128
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
129
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
130
+
131
+ if self.span_loss_type == "l1":
132
+ outputs_coord = outputs_coord.sigmoid()
133
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
134
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
135
+ outputs_coord = outputs_coord * idx_mask
136
+ else:
137
+ raise NotImplementedError
138
+
139
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
140
+ 'src_vid_mask': src_vid_mask}
141
+
142
+ vid_mem_proj = src_vid
143
+
144
+ # word-level -> sentence-level
145
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
146
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
147
+
148
+ out["vid_mem_proj"] = vid_mem_proj
149
+ out["txt_mem_proj"] = txt_mem_proj
150
+ if src_cls is not None:
151
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
152
+ out["cls_mem_proj"] = cls_mem_proj
153
+ out["saliency_scores"] = sim
154
+ return out
155
+
156
+ class SetCriterion(nn.Module):
157
+ """ This class computes the loss for DETR.
158
+ The process happens in two steps:
159
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
160
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
161
+ """
162
+
163
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
164
+ saliency_margin=1):
165
+ """ Create the criterion.
166
+ Parameters:
167
+ matcher: module able to compute a matching between targets and proposals
168
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
169
+ eos_coef: relative classification weight applied to the no-object category
170
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
171
+ temperature: float, temperature for NCE loss
172
+ span_loss_type: str, [l1, ce]
173
+ max_v_l: int,
174
+ saliency_margin: float
175
+ """
176
+ super().__init__()
177
+ self.matcher = matcher
178
+ self.weight_dict = weight_dict
179
+ self.losses = losses
180
+ self.temperature = temperature
181
+ self.span_loss_type = span_loss_type
182
+ self.max_v_l = max_v_l
183
+ self.saliency_margin = saliency_margin
184
+ self.temperature = 0.07
185
+
186
+ # foreground and background classification
187
+ self.foreground_label = 0
188
+ self.background_label = 1
189
+ self.eos_coef = eos_coef
190
+ empty_weight = torch.ones(2)
191
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
192
+ self.register_buffer('empty_weight', empty_weight)
193
+
194
+ def loss_spans(self, outputs, targets, indices):
195
+ assert 'pred_spans' in outputs
196
+
197
+ start_spans = targets['timestamp']
198
+ pred_spans = outputs['pred_spans']
199
+ src_spans = start_spans + pred_spans
200
+ gt_spans = targets['span_labels_nn']
201
+
202
+ mask = targets['timestamp_mask'].bool()
203
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
204
+ mask_valid = targets['timestamp_window'].bool()
205
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
206
+
207
+ weight_abalation_b = targets['weight_ablation'][:,0].unsqueeze(-1)
208
+ if weight_abalation_b.sum() == 0:
209
+ return {"loss_f": torch.tensor(0).cuda(), "loss_g": torch.tensor(0).cuda()}
210
+
211
+ mask_valid = (mask_valid * weight_abalation_b).bool()
212
+ mask_valid_full = (mask_valid_full * weight_abalation_b.unsqueeze(-1)).bool()
213
+
214
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
215
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
216
+
217
+ losses = {}
218
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
219
+ losses['loss_g'] = loss_giou.mean()
220
+ return losses
221
+
222
+ def loss_labels(self, outputs, targets, indices, log=True):
223
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
224
+ mask = targets['timestamp_mask'].bool()
225
+ mask_valid = targets['timestamp_window'].bool()
226
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
227
+ target_classes[mask_valid] = 1
228
+ # target_classes = targets['timestamp_window'] # soft cls.
229
+ target_classes.float()
230
+ # pdb.set_trace()
231
+
232
+ weights = torch.zeros_like(target_classes).float()
233
+ weights[mask] = self.empty_weight[1]
234
+ weights[mask_valid] = self.empty_weight[0]
235
+
236
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
237
+
238
+ weight_abalation_f = targets['weight_ablation'][:,2].unsqueeze(-1)
239
+ if weight_abalation_f.sum() == 0:
240
+ return {"loss_f": torch.tensor(0).cuda()}
241
+
242
+ mask = mask * weight_abalation_f
243
+ loss_ce = loss_ce * weight_abalation_f
244
+ return {"loss_f": loss_ce.sum() / mask.sum()}
245
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
246
+
247
+ def loss_saliency(self, outputs, targets, indices, log=True):
248
+ """higher scores for positive clips"""
249
+ if "saliency_pos_labels" not in targets:
250
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
251
+ saliency_scores = targets["saliency_scores"]
252
+ if saliency_scores.sum() == 0:
253
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
254
+
255
+ # * inter-vid mode
256
+ vid_mem_proj = outputs["vid_mem_proj"]
257
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
258
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
259
+
260
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
261
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
262
+ sim = sim_matrix(vid_feats, txt_feats)
263
+
264
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
265
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
266
+
267
+ # sum over positives
268
+ idiag = torch.diag(i_logsm)
269
+ jdiag = torch.diag(j_logsm)
270
+
271
+ weight_abalation_s = targets['weight_ablation'][:,3].bool()
272
+ if weight_abalation_s.sum() == 0:
273
+ return {"loss_s_inter": torch.tensor(0).cuda(),
274
+ "loss_s_intra": torch.tensor(0).cuda()}
275
+
276
+ _idiag = idiag[weight_abalation_s]
277
+ _jdiag = jdiag[weight_abalation_s]
278
+
279
+ loss_i = _idiag.sum() / len(_idiag)
280
+ loss_j = _jdiag.sum() / len(_jdiag)
281
+
282
+ loss_saliency_inter = - loss_i - loss_j
283
+
284
+ # * intra-vid mode
285
+ mask = targets['timestamp_mask']
286
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
287
+ neg_indices_in = (saliency_scores < selected_scores)
288
+ neg_indices_in[batch_indices, pos_indices] = True
289
+ mask_invalid = neg_indices_in * mask.bool()
290
+
291
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
292
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
293
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
294
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
295
+
296
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
297
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
298
+ _pos_logsm_in_i = pos_logsm_in_i[weight_abalation_s]
299
+ _pos_logsm_in_j = pos_logsm_in_j[weight_abalation_s]
300
+
301
+ loss_in_i = _pos_logsm_in_i.sum() / len(_pos_logsm_in_i)
302
+ loss_in_j = _pos_logsm_in_j.sum() / len(_pos_logsm_in_j)
303
+
304
+ loss_saliency_intra = - loss_in_i - loss_in_j
305
+
306
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
307
+
308
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
309
+ """higher scores for positive clips"""
310
+ if "saliency_pos_labels" not in targets:
311
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
312
+ saliency_scores = targets["saliency_scores"]
313
+ if saliency_scores.sum() == 0:
314
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
315
+
316
+ # * inter-vid mode
317
+ vid_mem_proj = outputs["vid_mem_proj"]
318
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
319
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
320
+
321
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
322
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
323
+ sim = sim_matrix(vid_feats, txt_feats)
324
+
325
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
326
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
327
+
328
+ # sum over positives
329
+ idiag = torch.diag(i_logsm)
330
+ jdiag = torch.diag(j_logsm)
331
+ loss_i = idiag.sum() / len(idiag)
332
+ loss_j = jdiag.sum() / len(jdiag)
333
+
334
+ loss_saliency_inter = - loss_i - loss_j
335
+
336
+ # * intra-vid mode
337
+ if 'cls_idx' not in targets.keys(): # eval
338
+ return {"loss_s_inter": loss_saliency_inter}
339
+
340
+ cls_indices = targets['cls_idx'].bool()
341
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
342
+ sim_cls = sim_matrix(vid_feats, cls_feats)
343
+
344
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
345
+ idiag_cls = i_logsm_cls[cls_indices]
346
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
347
+
348
+ loss_saliency_intra = - loss_cls_i
349
+
350
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
351
+
352
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
353
+ loss_map = {
354
+ "spans": self.loss_spans,
355
+ "labels": self.loss_labels,
356
+ "saliency": self.loss_saliency,
357
+ "saliency_cls": self.loss_saliency_cls,
358
+ }
359
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
360
+ return loss_map[loss](outputs, targets, indices, **kwargs)
361
+
362
+ def forward(self, outputs, targets, hl_only=False):
363
+ """ This performs the loss computation.
364
+ Parameters:
365
+ outputs: dict of tensors, see the output specification of the model for the format
366
+ targets: list of dicts, such that len(targets) == batch_size.
367
+ The expected keys in each dict depends on the losses applied, see each loss' doc
368
+ """
369
+ indices = None
370
+ # Compute all the requested losses
371
+ losses = {}
372
+ for loss in self.losses:
373
+ losses.update(self.get_loss(loss, outputs, targets, indices))
374
+
375
+ return losses
376
+
377
+ class MLP(nn.Module):
378
+ """ Very simple multi-layer perceptron (also called FFN)"""
379
+
380
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
381
+ super().__init__()
382
+ self.num_layers = num_layers
383
+ h = [hidden_dim] * (num_layers - 1)
384
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
385
+
386
+ def forward(self, x):
387
+ for i, layer in enumerate(self.layers):
388
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
389
+ return x
390
+
391
+ class Conv(nn.Module):
392
+ """ Very simple multi-layer perceptron (also called FFN)"""
393
+
394
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
395
+ super().__init__()
396
+ self.num_layers = num_layers
397
+ h = [hidden_dim] * (num_layers - 1)
398
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
399
+ self.layers = nn.ModuleList(
400
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
401
+ for n, k in zip([input_dim] + h, h + [output_dim]))
402
+ def forward(self, x):
403
+ x = x.permute(0,2,1)
404
+ for i, layer in enumerate(self.layers):
405
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
406
+ return x.permute(0, 2, 1)
407
+
408
+ class LinearLayer(nn.Module):
409
+ """linear layer configurable with layer normalization, dropout, ReLU."""
410
+
411
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
412
+ super(LinearLayer, self).__init__()
413
+ self.relu = relu
414
+ self.layer_norm = layer_norm
415
+ if layer_norm:
416
+ self.LayerNorm = nn.LayerNorm(in_hsz)
417
+ layers = [
418
+ nn.Dropout(dropout),
419
+ nn.Linear(in_hsz, out_hsz)
420
+ ]
421
+ self.net = nn.Sequential(*layers)
422
+
423
+ def forward(self, x):
424
+ """(N, L, D)"""
425
+ if self.layer_norm:
426
+ x = self.LayerNorm(x)
427
+ x = self.net(x)
428
+ if self.relu:
429
+ x = F.relu(x, inplace=True)
430
+ return x # (N, L, D)
431
+
432
+
433
+ def build_model(args):
434
+ device = torch.device(args.device)
435
+
436
+ transformer = build_transformer(args)
437
+ position_embedding, txt_position_embedding = build_position_encoding(args)
438
+
439
+ model = Model(
440
+ transformer,
441
+ position_embedding,
442
+ txt_position_embedding,
443
+ txt_dim=args.t_feat_dim,
444
+ vid_dim=args.v_feat_dim,
445
+ input_dropout=args.input_dropout,
446
+ span_loss_type=args.span_loss_type,
447
+ use_txt_pos=args.use_txt_pos,
448
+ n_input_proj=args.n_input_proj,
449
+ )
450
+
451
+ matcher = build_matcher(args)
452
+ weight_dict = {"loss_b": args.b_loss_coef,
453
+ "loss_g": args.g_loss_coef,
454
+ "loss_f": args.f_loss_coef,
455
+ "loss_s_intra": args.s_loss_intra_coef,
456
+ "loss_s_inter": args.s_loss_inter_coef}
457
+
458
+ if args.dset_type in ['mr', 'vlp']:
459
+ if 'tal' not in args.train_path:
460
+ losses = ['spans', 'labels', 'saliency']
461
+ else:
462
+ losses = ['spans', 'labels', 'saliency_cls']
463
+ elif args.dset_type in ['hl', 'vs']:
464
+ losses = ['labels', 'saliency']
465
+
466
+ criterion = SetCriterion(
467
+ matcher=matcher,
468
+ weight_dict=weight_dict, losses=losses,
469
+ eos_coef=args.eos_coef, temperature=args.temperature,
470
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
471
+ saliency_margin=args.saliency_margin,
472
+ )
473
+ criterion.to(device)
474
+ return model, criterion
model/base_droppath_qfvs.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder_droppath import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ # MLP Projector
103
+ self.weightedpool = WeightedPool(hidden_dim)
104
+
105
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
106
+ bs = src_vid.shape[0]
107
+ src_vid = self.input_vid_proj(src_vid)
108
+ src_txt = self.input_txt_proj(src_txt)
109
+ if src_cls is not None:
110
+ src_cls = self.input_txt_proj(src_cls)
111
+
112
+ # type token.
113
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
114
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
115
+ if src_cls is not None:
116
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
117
+
118
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
119
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
120
+
121
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
122
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
123
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
124
+
125
+ memory = self.transformer(src, ~mask, pos)
126
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
127
+
128
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
129
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
130
+
131
+ if self.span_loss_type == "l1":
132
+ outputs_coord = outputs_coord.sigmoid()
133
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
134
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
135
+ outputs_coord = outputs_coord * idx_mask
136
+ else:
137
+ raise NotImplementedError
138
+
139
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
140
+ 'src_vid_mask': src_vid_mask}
141
+
142
+ vid_mem_proj = src_vid
143
+
144
+ # word-level -> sentence-level
145
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
146
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
147
+
148
+ out["vid_mem_proj"] = vid_mem_proj
149
+ out["txt_mem_proj"] = txt_mem_proj
150
+ if src_cls is not None:
151
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
152
+ out["cls_mem_proj"] = cls_mem_proj
153
+ out["saliency_scores"] = sim
154
+ return out
155
+
156
+ class SetCriterion(nn.Module):
157
+ """ This class computes the loss for DETR.
158
+ The process happens in two steps:
159
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
160
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
161
+ """
162
+
163
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
164
+ saliency_margin=1):
165
+ """ Create the criterion.
166
+ Parameters:
167
+ matcher: module able to compute a matching between targets and proposals
168
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
169
+ eos_coef: relative classification weight applied to the no-object category
170
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
171
+ temperature: float, temperature for NCE loss
172
+ span_loss_type: str, [l1, ce]
173
+ max_v_l: int,
174
+ saliency_margin: float
175
+ """
176
+ super().__init__()
177
+ self.matcher = matcher
178
+ self.weight_dict = weight_dict
179
+ self.losses = losses
180
+ self.temperature = temperature
181
+ self.span_loss_type = span_loss_type
182
+ self.max_v_l = max_v_l
183
+ self.saliency_margin = saliency_margin
184
+ self.temperature = 0.07
185
+
186
+ # foreground and background classification
187
+ self.foreground_label = 0
188
+ self.background_label = 1
189
+ self.eos_coef = eos_coef
190
+ empty_weight = torch.ones(2)
191
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
192
+ self.register_buffer('empty_weight', empty_weight)
193
+
194
+ def loss_spans(self, outputs, targets, indices):
195
+ assert 'pred_spans' in outputs
196
+
197
+ start_spans = targets['timestamp']
198
+ pred_spans = outputs['pred_spans']
199
+ src_spans = start_spans + pred_spans
200
+ gt_spans = targets['span_labels_nn']
201
+
202
+ mask = targets['timestamp_mask'].bool()
203
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
204
+ mask_valid = targets['timestamp_window'].bool()
205
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
206
+
207
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
208
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
209
+
210
+ losses = {}
211
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
212
+ losses['loss_g'] = loss_giou.mean()
213
+ return losses
214
+
215
+ def loss_labels(self, outputs, targets, indices, log=True):
216
+ saliency_scores = targets["saliency_scores"]
217
+ if saliency_scores.sum() == 0:
218
+ return {"loss_f": 0.}
219
+
220
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
221
+ target_classes = targets["saliency_scores"].squeeze()
222
+
223
+ weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
224
+ weights[target_classes.bool()] = self.empty_weight[0]
225
+
226
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
227
+ return {"loss_f": loss_ce.sum() / target_classes.sum()}
228
+ # return {"loss_f": loss_ce.sum() / len(target_classes)}
229
+
230
+ # mask = targets['timestamp_mask'].bool()
231
+ # mask_valid = targets['timestamp_window'].bool()
232
+ # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
233
+ # target_classes[mask_valid] = 1
234
+ # # target_classes = targets['timestamp_window'] # soft cls.
235
+ # target_classes.float()
236
+ # # pdb.set_trace()
237
+
238
+ # weights = torch.zeros_like(target_classes).float()
239
+ # weights[mask] = self.empty_weight[1]
240
+ # weights[mask_valid] = self.empty_weight[0]
241
+
242
+ # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
243
+ # # return {"loss_f": loss_ce.sum() / mask.sum()}
244
+ # return {"loss_f": loss_ce.sum() / mask_valid.sum()}
245
+
246
+ def loss_saliency(self, outputs, targets, indices, log=True):
247
+ """higher scores for positive clips"""
248
+ if "saliency_pos_labels" not in targets:
249
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
250
+ saliency_scores = targets["saliency_scores"]
251
+ if saliency_scores.sum() == 0:
252
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
253
+
254
+ # * qfvs mil-nce mode
255
+ pos_indices = saliency_scores.squeeze() > 0
256
+
257
+ sim = outputs['saliency_scores']
258
+ sim_soft = F.softmax(sim / self.temperature, dim=0)
259
+ sim_log = torch.log(sim_soft[pos_indices])
260
+ loss_saliency_intra = -sim_log.sum() / len(sim_log)
261
+ return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}
262
+
263
+ # * inter-vid mode
264
+ # vid_mem_proj = outputs["vid_mem_proj"]
265
+ # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
266
+ # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
267
+
268
+ # vid_feats = vid_mem_proj[batch_indices, pos_indices]
269
+ # txt_feats = outputs["txt_mem_proj"].squeeze(1)
270
+ # sim = sim_matrix(vid_feats, txt_feats)
271
+
272
+ # i_logsm = F.log_softmax(sim / self.temperature, dim=1)
273
+ # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
274
+
275
+ # # sum over positives
276
+ # idiag = torch.diag(i_logsm)
277
+ # jdiag = torch.diag(j_logsm)
278
+ # loss_i = idiag.sum() / len(idiag)
279
+ # loss_j = jdiag.sum() / len(jdiag)
280
+
281
+ # loss_saliency_inter = - loss_i - loss_j
282
+
283
+ # # * intra-vid mode
284
+ # mask = targets['timestamp_mask']
285
+ # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
286
+ # neg_indices_in = (saliency_scores < selected_scores)
287
+ # neg_indices_in[batch_indices, pos_indices] = True
288
+ # mask_invalid = neg_indices_in * mask.bool()
289
+
290
+ # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
291
+ # sim_in = sim_in + (mask_invalid + 1e-45).log()
292
+ # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
293
+ # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
294
+
295
+ # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
296
+ # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
297
+ # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
298
+ # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
299
+
300
+ # loss_saliency_intra = - loss_in_i - loss_in_j
301
+
302
+ # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
303
+
304
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
305
+ """higher scores for positive clips"""
306
+ if "saliency_pos_labels" not in targets:
307
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
308
+ saliency_scores = targets["saliency_scores"]
309
+ if saliency_scores.sum() == 0:
310
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
311
+
312
+ # * inter-vid mode
313
+ vid_mem_proj = outputs["vid_mem_proj"]
314
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
315
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
316
+
317
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
318
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
319
+ sim = sim_matrix(vid_feats, txt_feats)
320
+
321
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
322
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
323
+
324
+ # sum over positives
325
+ idiag = torch.diag(i_logsm)
326
+ jdiag = torch.diag(j_logsm)
327
+ loss_i = idiag.sum() / len(idiag)
328
+ loss_j = jdiag.sum() / len(jdiag)
329
+
330
+ loss_saliency_inter = - loss_i - loss_j
331
+
332
+ # * intra-vid mode
333
+ if 'cls_idx' not in targets.keys(): # eval
334
+ return {"loss_s_inter": loss_saliency_inter}
335
+
336
+ cls_indices = targets['cls_idx'].bool()
337
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
338
+ sim_cls = sim_matrix(vid_feats, cls_feats)
339
+
340
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
341
+ idiag_cls = i_logsm_cls[cls_indices]
342
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
343
+
344
+ loss_saliency_intra = - loss_cls_i
345
+
346
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
347
+
348
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
349
+ loss_map = {
350
+ "spans": self.loss_spans,
351
+ "labels": self.loss_labels,
352
+ "saliency": self.loss_saliency,
353
+ "saliency_cls": self.loss_saliency_cls,
354
+ }
355
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
356
+ return loss_map[loss](outputs, targets, indices, **kwargs)
357
+
358
+ def forward(self, outputs, targets, mask_GT=None):
359
+ """ This performs the loss computation.
360
+ Parameters:
361
+ outputs: dict of tensors, see the output specification of the model for the format
362
+ targets: list of dicts, such that len(targets) == batch_size.
363
+ The expected keys in each dict depends on the losses applied, see each loss' doc
364
+ """
365
+ indices = None
366
+ # Compute all the requested losses
367
+ losses = {}
368
+ outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
369
+ count = mask_GT.sum()
370
+ outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
371
+ # targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])
372
+ targets['saliency_scores'] = targets['saliency_scores'][0,:count]
373
+
374
+ for loss in self.losses:
375
+ losses.update(self.get_loss(loss, outputs, targets, indices))
376
+
377
+ return losses
378
+
379
+ class MLP(nn.Module):
380
+ """ Very simple multi-layer perceptron (also called FFN)"""
381
+
382
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
383
+ super().__init__()
384
+ self.num_layers = num_layers
385
+ h = [hidden_dim] * (num_layers - 1)
386
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
387
+
388
+ def forward(self, x):
389
+ for i, layer in enumerate(self.layers):
390
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
391
+ return x
392
+
393
+ class Conv(nn.Module):
394
+ """ Very simple multi-layer perceptron (also called FFN)"""
395
+
396
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
397
+ super().__init__()
398
+ self.num_layers = num_layers
399
+ h = [hidden_dim] * (num_layers - 1)
400
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
401
+ self.layers = nn.ModuleList(
402
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
403
+ for n, k in zip([input_dim] + h, h + [output_dim]))
404
+ def forward(self, x):
405
+ x = x.permute(0,2,1)
406
+ for i, layer in enumerate(self.layers):
407
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
408
+ return x.permute(0, 2, 1)
409
+
410
+ class LinearLayer(nn.Module):
411
+ """linear layer configurable with layer normalization, dropout, ReLU."""
412
+
413
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
414
+ super(LinearLayer, self).__init__()
415
+ self.relu = relu
416
+ self.layer_norm = layer_norm
417
+ if layer_norm:
418
+ self.LayerNorm = nn.LayerNorm(in_hsz)
419
+ layers = [
420
+ nn.Dropout(dropout),
421
+ nn.Linear(in_hsz, out_hsz)
422
+ ]
423
+ self.net = nn.Sequential(*layers)
424
+
425
+ def forward(self, x):
426
+ """(N, L, D)"""
427
+ if self.layer_norm:
428
+ x = self.LayerNorm(x)
429
+ x = self.net(x)
430
+ if self.relu:
431
+ x = F.relu(x, inplace=True)
432
+ return x # (N, L, D)
433
+
434
+
435
+ def build_model(args):
436
+ device = torch.device(args.device)
437
+
438
+ transformer = build_transformer(args)
439
+ position_embedding, txt_position_embedding = build_position_encoding(args)
440
+
441
+ model = Model(
442
+ transformer,
443
+ position_embedding,
444
+ txt_position_embedding,
445
+ txt_dim=args.t_feat_dim,
446
+ vid_dim=args.v_feat_dim,
447
+ input_dropout=args.input_dropout,
448
+ span_loss_type=args.span_loss_type,
449
+ use_txt_pos=args.use_txt_pos,
450
+ n_input_proj=args.n_input_proj,
451
+ )
452
+
453
+ matcher = build_matcher(args)
454
+ weight_dict = {"loss_b": args.b_loss_coef,
455
+ "loss_g": args.g_loss_coef,
456
+ "loss_f": args.f_loss_coef,
457
+ "loss_s_intra": args.s_loss_intra_coef,
458
+ "loss_s_inter": args.s_loss_inter_coef}
459
+
460
+ if args.dset_type in ['mr', 'vlp']:
461
+ if 'tal' not in args.train_path:
462
+ losses = ['spans', 'labels', 'saliency']
463
+ else:
464
+ losses = ['spans', 'labels', 'saliency_cls']
465
+ elif args.dset_type in ['hl', 'vs']:
466
+ losses = ['labels', 'saliency']
467
+
468
+ criterion = SetCriterion(
469
+ matcher=matcher,
470
+ weight_dict=weight_dict, losses=losses,
471
+ eos_coef=args.eos_coef, temperature=args.temperature,
472
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
473
+ saliency_margin=args.saliency_margin,
474
+ )
475
+ criterion.to(device)
476
+ return model, criterion
model/base_prompt.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.prompt_learner = nn.Embedding(10, hidden_dim)
81
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
82
+ self.token_type_embeddings.apply(init_weights)
83
+
84
+ # Conv projector
85
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
86
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
87
+
88
+ self.use_txt_pos = use_txt_pos
89
+ self.n_input_proj = n_input_proj
90
+ relu_args = [True] * 3
91
+ relu_args[n_input_proj-1] = False
92
+ self.input_txt_proj = nn.Sequential(*[
93
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
95
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
96
+ ][:n_input_proj])
97
+ self.input_vid_proj = nn.Sequential(*[
98
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
100
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
101
+ ][:n_input_proj])
102
+
103
+ # MLP Projector
104
+ self.weightedpool = WeightedPool(hidden_dim)
105
+
106
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
107
+ bs = src_vid.shape[0]
108
+ src_vid = self.input_vid_proj(src_vid)
109
+ src_txt = self.input_txt_proj(src_txt)
110
+ if src_cls is not None:
111
+ src_cls = self.input_txt_proj(src_cls)
112
+
113
+ src_prompt = self.prompt_learner.weight.unsqueeze(0).repeat(bs, 1, 1)
114
+ src_prompt_mask = torch.ones((bs, src_prompt.shape[1])).cuda()
115
+
116
+ if self.training:
117
+ # src_txt = src_prompt
118
+ # src_txt_mask = torch.ones_like(src_prompt).cuda()
119
+ src_txt = torch.cat([src_prompt, src_txt], dim=1)
120
+ src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1)
121
+ else:
122
+ src_txt = torch.cat([src_prompt, src_txt], dim=1)
123
+ src_txt_mask = torch.cat([src_prompt_mask, src_txt_mask], dim=1)
124
+
125
+ # type token.
126
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
127
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
128
+ if src_cls is not None:
129
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
130
+
131
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
132
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
133
+
134
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
135
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
136
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
137
+
138
+ memory = self.transformer(src, ~mask, pos)
139
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
140
+
141
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
142
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
143
+
144
+ if self.span_loss_type == "l1":
145
+ outputs_coord = outputs_coord.sigmoid()
146
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
147
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
148
+ outputs_coord = outputs_coord * idx_mask
149
+ else:
150
+ raise NotImplementedError
151
+
152
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
153
+ 'src_vid_mask': src_vid_mask}
154
+
155
+ vid_mem_proj = src_vid
156
+
157
+ # word-level -> sentence-level
158
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
159
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
160
+
161
+ out["vid_mem_proj"] = vid_mem_proj
162
+ out["txt_mem_proj"] = txt_mem_proj
163
+ if src_cls is not None:
164
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
165
+ out["cls_mem_proj"] = cls_mem_proj
166
+ out["saliency_scores"] = sim
167
+ return out
168
+
169
+ class SetCriterion(nn.Module):
170
+ """ This class computes the loss for DETR.
171
+ The process happens in two steps:
172
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
173
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
174
+ """
175
+
176
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
177
+ saliency_margin=1):
178
+ """ Create the criterion.
179
+ Parameters:
180
+ matcher: module able to compute a matching between targets and proposals
181
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
182
+ eos_coef: relative classification weight applied to the no-object category
183
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
184
+ temperature: float, temperature for NCE loss
185
+ span_loss_type: str, [l1, ce]
186
+ max_v_l: int,
187
+ saliency_margin: float
188
+ """
189
+ super().__init__()
190
+ self.matcher = matcher
191
+ self.weight_dict = weight_dict
192
+ self.losses = losses
193
+ self.temperature = temperature
194
+ self.span_loss_type = span_loss_type
195
+ self.max_v_l = max_v_l
196
+ self.saliency_margin = saliency_margin
197
+ self.temperature = 0.07
198
+
199
+ # foreground and background classification
200
+ self.foreground_label = 0
201
+ self.background_label = 1
202
+ self.eos_coef = eos_coef
203
+ empty_weight = torch.ones(2)
204
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
205
+ self.register_buffer('empty_weight', empty_weight)
206
+
207
+ def loss_spans(self, outputs, targets, indices):
208
+ assert 'pred_spans' in outputs
209
+
210
+ start_spans = targets['timestamp']
211
+ pred_spans = outputs['pred_spans']
212
+ src_spans = start_spans + pred_spans
213
+ gt_spans = targets['span_labels_nn']
214
+
215
+ mask = targets['timestamp_mask'].bool()
216
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
217
+ mask_valid = targets['timestamp_window'].bool()
218
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
219
+
220
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
221
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
222
+
223
+ losses = {}
224
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
225
+ losses['loss_g'] = loss_giou.mean()
226
+ return losses
227
+
228
+ def loss_labels(self, outputs, targets, indices, log=True):
229
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
230
+ mask = targets['timestamp_mask'].bool()
231
+ mask_valid = targets['timestamp_window'].bool()
232
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
233
+ target_classes[mask_valid] = 1
234
+ # target_classes = targets['timestamp_window'] # soft cls.
235
+ target_classes.float()
236
+ # pdb.set_trace()
237
+
238
+ weights = torch.zeros_like(target_classes).float()
239
+ weights[mask] = self.empty_weight[1]
240
+ weights[mask_valid] = self.empty_weight[0]
241
+
242
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
243
+ return {"loss_f": loss_ce.sum() / mask.sum()}
244
+
245
+ def loss_saliency(self, outputs, targets, indices, log=True):
246
+ """higher scores for positive clips"""
247
+ if "saliency_pos_labels" not in targets:
248
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
249
+ saliency_scores = targets["saliency_scores"]
250
+ if saliency_scores.sum() == 0:
251
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
252
+
253
+ # * inter-vid mode
254
+ vid_mem_proj = outputs["vid_mem_proj"]
255
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
256
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
257
+
258
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
259
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
260
+ sim = sim_matrix(vid_feats, txt_feats)
261
+
262
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
263
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
264
+
265
+ # sum over positives
266
+ idiag = torch.diag(i_logsm)
267
+ jdiag = torch.diag(j_logsm)
268
+ loss_i = idiag.sum() / len(idiag)
269
+ loss_j = jdiag.sum() / len(jdiag)
270
+
271
+ loss_saliency_inter = - loss_i - loss_j
272
+
273
+ # * intra-vid mode
274
+ mask = targets['timestamp_mask']
275
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
276
+ neg_indices_in = (saliency_scores < selected_scores)
277
+ neg_indices_in[batch_indices, pos_indices] = True
278
+ mask_invalid = neg_indices_in * mask.bool()
279
+
280
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
281
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
282
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
283
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
284
+
285
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
286
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
287
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
288
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
289
+
290
+ loss_saliency_intra = - loss_in_i - loss_in_j
291
+
292
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
293
+
294
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
295
+ """higher scores for positive clips"""
296
+ if "saliency_pos_labels" not in targets:
297
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
298
+ saliency_scores = targets["saliency_scores"]
299
+ if saliency_scores.sum() == 0:
300
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
301
+
302
+ # * inter-vid mode
303
+ vid_mem_proj = outputs["vid_mem_proj"]
304
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
305
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
306
+
307
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
308
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
309
+ sim = sim_matrix(vid_feats, txt_feats)
310
+
311
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
312
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
313
+
314
+ # sum over positives
315
+ idiag = torch.diag(i_logsm)
316
+ jdiag = torch.diag(j_logsm)
317
+ loss_i = idiag.sum() / len(idiag)
318
+ loss_j = jdiag.sum() / len(jdiag)
319
+
320
+ loss_saliency_inter = - loss_i - loss_j
321
+
322
+ # * intra-vid mode
323
+ if 'cls_idx' not in targets.keys(): # eval
324
+ return {"loss_s_inter": loss_saliency_inter}
325
+
326
+ cls_indices = targets['cls_idx'].bool()
327
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
328
+ sim_cls = sim_matrix(vid_feats, cls_feats)
329
+
330
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
331
+ idiag_cls = i_logsm_cls[cls_indices]
332
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
333
+
334
+ loss_saliency_intra = - loss_cls_i
335
+
336
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
337
+
338
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
339
+ loss_map = {
340
+ "spans": self.loss_spans,
341
+ "labels": self.loss_labels,
342
+ "saliency": self.loss_saliency,
343
+ "saliency_cls": self.loss_saliency_cls,
344
+ }
345
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
346
+ return loss_map[loss](outputs, targets, indices, **kwargs)
347
+
348
+ def forward(self, outputs, targets, hl_only=False):
349
+ """ This performs the loss computation.
350
+ Parameters:
351
+ outputs: dict of tensors, see the output specification of the model for the format
352
+ targets: list of dicts, such that len(targets) == batch_size.
353
+ The expected keys in each dict depends on the losses applied, see each loss' doc
354
+ """
355
+ indices = None
356
+ # Compute all the requested losses
357
+ losses = {}
358
+ for loss in self.losses:
359
+ losses.update(self.get_loss(loss, outputs, targets, indices))
360
+
361
+ return losses
362
+
363
+ class MLP(nn.Module):
364
+ """ Very simple multi-layer perceptron (also called FFN)"""
365
+
366
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
367
+ super().__init__()
368
+ self.num_layers = num_layers
369
+ h = [hidden_dim] * (num_layers - 1)
370
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
371
+
372
+ def forward(self, x):
373
+ for i, layer in enumerate(self.layers):
374
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
375
+ return x
376
+
377
+ class Conv(nn.Module):
378
+ """ Very simple multi-layer perceptron (also called FFN)"""
379
+
380
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
381
+ super().__init__()
382
+ self.num_layers = num_layers
383
+ h = [hidden_dim] * (num_layers - 1)
384
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
385
+ self.layers = nn.ModuleList(
386
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
387
+ for n, k in zip([input_dim] + h, h + [output_dim]))
388
+ def forward(self, x):
389
+ x = x.permute(0,2,1)
390
+ for i, layer in enumerate(self.layers):
391
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
392
+ return x.permute(0, 2, 1)
393
+
394
+ class LinearLayer(nn.Module):
395
+ """linear layer configurable with layer normalization, dropout, ReLU."""
396
+
397
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
398
+ super(LinearLayer, self).__init__()
399
+ self.relu = relu
400
+ self.layer_norm = layer_norm
401
+ if layer_norm:
402
+ self.LayerNorm = nn.LayerNorm(in_hsz)
403
+ layers = [
404
+ nn.Dropout(dropout),
405
+ nn.Linear(in_hsz, out_hsz)
406
+ ]
407
+ self.net = nn.Sequential(*layers)
408
+
409
+ def forward(self, x):
410
+ """(N, L, D)"""
411
+ if self.layer_norm:
412
+ x = self.LayerNorm(x)
413
+ x = self.net(x)
414
+ if self.relu:
415
+ x = F.relu(x, inplace=True)
416
+ return x # (N, L, D)
417
+
418
+
419
+ def build_model(args):
420
+ device = torch.device(args.device)
421
+
422
+ transformer = build_transformer(args)
423
+ position_embedding, txt_position_embedding = build_position_encoding(args)
424
+
425
+ model = Model(
426
+ transformer,
427
+ position_embedding,
428
+ txt_position_embedding,
429
+ txt_dim=args.t_feat_dim,
430
+ vid_dim=args.v_feat_dim,
431
+ input_dropout=args.input_dropout,
432
+ span_loss_type=args.span_loss_type,
433
+ use_txt_pos=args.use_txt_pos,
434
+ n_input_proj=args.n_input_proj,
435
+ )
436
+
437
+ matcher = build_matcher(args)
438
+ weight_dict = {"loss_b": args.b_loss_coef,
439
+ "loss_g": args.g_loss_coef,
440
+ "loss_f": args.f_loss_coef,
441
+ "loss_s_intra": args.s_loss_intra_coef,
442
+ "loss_s_inter": args.s_loss_inter_coef}
443
+
444
+ if args.dset_type in ['mr']:
445
+ if 'tal' not in args.train_path:
446
+ losses = ['spans', 'labels', 'saliency']
447
+ else:
448
+ losses = ['spans', 'labels', 'saliency_cls']
449
+ elif args.dset_type in ['hl', 'vs']:
450
+ losses = ['labels', 'saliency']
451
+
452
+ criterion = SetCriterion(
453
+ matcher=matcher,
454
+ weight_dict=weight_dict, losses=losses,
455
+ eos_coef=args.eos_coef, temperature=args.temperature,
456
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
457
+ saliency_margin=args.saliency_margin,
458
+ )
459
+ criterion.to(device)
460
+ return model, criterion
model/base_qfvs.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder_droppath import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ # MLP Projector
103
+ self.weightedpool = WeightedPool(hidden_dim)
104
+
105
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
106
+ bs = src_vid.shape[0]
107
+ src_vid = self.input_vid_proj(src_vid)
108
+ src_txt = self.input_txt_proj(src_txt)
109
+ if src_cls is not None:
110
+ src_cls = self.input_txt_proj(src_cls)
111
+
112
+ # type token.
113
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
114
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
115
+ if src_cls is not None:
116
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
117
+
118
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
119
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
120
+
121
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
122
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
123
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
124
+
125
+ memory = self.transformer(src, ~mask, pos)
126
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
127
+
128
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
129
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
130
+
131
+ if self.span_loss_type == "l1":
132
+ outputs_coord = outputs_coord.sigmoid()
133
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
134
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
135
+ outputs_coord = outputs_coord * idx_mask
136
+ else:
137
+ raise NotImplementedError
138
+
139
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
140
+ 'src_vid_mask': src_vid_mask}
141
+
142
+ vid_mem_proj = src_vid
143
+
144
+ # word-level -> sentence-level
145
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
146
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
147
+
148
+ out["vid_mem_proj"] = vid_mem_proj
149
+ out["txt_mem_proj"] = txt_mem_proj
150
+ if src_cls is not None:
151
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
152
+ out["cls_mem_proj"] = cls_mem_proj
153
+ out["saliency_scores"] = sim
154
+ return out
155
+
156
+ class SetCriterion(nn.Module):
157
+ """ This class computes the loss for DETR.
158
+ The process happens in two steps:
159
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
160
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
161
+ """
162
+
163
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
164
+ saliency_margin=1):
165
+ """ Create the criterion.
166
+ Parameters:
167
+ matcher: module able to compute a matching between targets and proposals
168
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
169
+ eos_coef: relative classification weight applied to the no-object category
170
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
171
+ temperature: float, temperature for NCE loss
172
+ span_loss_type: str, [l1, ce]
173
+ max_v_l: int,
174
+ saliency_margin: float
175
+ """
176
+ super().__init__()
177
+ self.matcher = matcher
178
+ self.weight_dict = weight_dict
179
+ self.losses = losses
180
+ self.temperature = temperature
181
+ self.span_loss_type = span_loss_type
182
+ self.max_v_l = max_v_l
183
+ self.saliency_margin = saliency_margin
184
+ self.temperature = 0.07
185
+
186
+ # foreground and background classification
187
+ self.foreground_label = 0
188
+ self.background_label = 1
189
+ self.eos_coef = eos_coef
190
+ empty_weight = torch.ones(2)
191
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
192
+ self.register_buffer('empty_weight', empty_weight)
193
+
194
+ def loss_spans(self, outputs, targets, indices):
195
+ assert 'pred_spans' in outputs
196
+
197
+ start_spans = targets['timestamp']
198
+ pred_spans = outputs['pred_spans']
199
+ src_spans = start_spans + pred_spans
200
+ gt_spans = targets['span_labels_nn']
201
+
202
+ mask = targets['timestamp_mask'].bool()
203
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
204
+ mask_valid = targets['timestamp_window'].bool()
205
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
206
+
207
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
208
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
209
+
210
+ losses = {}
211
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
212
+ losses['loss_g'] = loss_giou.mean()
213
+ return losses
214
+
215
+ def loss_labels(self, outputs, targets, indices, log=True):
216
+ saliency_scores = targets["saliency_scores"]
217
+ if saliency_scores.sum() == 0:
218
+ return {"loss_f": 0.}
219
+
220
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
221
+ target_classes = targets["saliency_scores"].squeeze()
222
+
223
+ weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
224
+ weights[target_classes.bool()] = self.empty_weight[0]
225
+
226
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
227
+ # pdb.set_trace()
228
+ return {"loss_f": loss_ce.sum() / target_classes.sum()}
229
+ # return {"loss_f": loss_ce.sum() / len(target_classes)}
230
+
231
+ # mask = targets['timestamp_mask'].bool()
232
+ # mask_valid = targets['timestamp_window'].bool()
233
+ # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
234
+ # target_classes[mask_valid] = 1
235
+ # # target_classes = targets['timestamp_window'] # soft cls.
236
+ # target_classes.float()
237
+ # # pdb.set_trace()
238
+
239
+ # weights = torch.zeros_like(target_classes).float()
240
+ # weights[mask] = self.empty_weight[1]
241
+ # weights[mask_valid] = self.empty_weight[0]
242
+
243
+ # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
244
+ # # return {"loss_f": loss_ce.sum() / mask.sum()}
245
+ # return {"loss_f": loss_ce.sum() / mask_valid.sum()}
246
+
247
+ def loss_saliency(self, outputs, targets, indices, log=True):
248
+ """higher scores for positive clips"""
249
+ if "saliency_pos_labels" not in targets:
250
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
251
+ saliency_scores = targets["saliency_scores"]
252
+ if saliency_scores.sum() == 0:
253
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
254
+
255
+ # * qfvs mil-nce mode
256
+ pos_indices = saliency_scores.squeeze() > 0
257
+
258
+ sim = outputs['saliency_scores']
259
+ sim_soft = F.softmax(sim / self.temperature, dim=0)
260
+ sim_log = torch.log(sim_soft[pos_indices])
261
+ loss_saliency_intra = -sim_log.sum() / len(sim_log)
262
+ return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}
263
+
264
+ # * inter-vid mode
265
+ # vid_mem_proj = outputs["vid_mem_proj"]
266
+ # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
267
+ # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
268
+
269
+ # vid_feats = vid_mem_proj[batch_indices, pos_indices]
270
+ # txt_feats = outputs["txt_mem_proj"].squeeze(1)
271
+ # sim = sim_matrix(vid_feats, txt_feats)
272
+
273
+ # i_logsm = F.log_softmax(sim / self.temperature, dim=1)
274
+ # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
275
+
276
+ # # sum over positives
277
+ # idiag = torch.diag(i_logsm)
278
+ # jdiag = torch.diag(j_logsm)
279
+ # loss_i = idiag.sum() / len(idiag)
280
+ # loss_j = jdiag.sum() / len(jdiag)
281
+
282
+ # loss_saliency_inter = - loss_i - loss_j
283
+
284
+ # # * intra-vid mode
285
+ # mask = targets['timestamp_mask']
286
+ # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
287
+ # neg_indices_in = (saliency_scores < selected_scores)
288
+ # neg_indices_in[batch_indices, pos_indices] = True
289
+ # mask_invalid = neg_indices_in * mask.bool()
290
+
291
+ # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
292
+ # sim_in = sim_in + (mask_invalid + 1e-45).log()
293
+ # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
294
+ # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
295
+
296
+ # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
297
+ # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
298
+ # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
299
+ # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
300
+
301
+ # loss_saliency_intra = - loss_in_i - loss_in_j
302
+
303
+ # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
304
+
305
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
306
+ """higher scores for positive clips"""
307
+ if "saliency_pos_labels" not in targets:
308
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
309
+ saliency_scores = targets["saliency_scores"]
310
+ if saliency_scores.sum() == 0:
311
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
312
+
313
+ # * inter-vid mode
314
+ vid_mem_proj = outputs["vid_mem_proj"]
315
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
316
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
317
+
318
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
319
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
320
+ sim = sim_matrix(vid_feats, txt_feats)
321
+
322
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
323
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
324
+
325
+ # sum over positives
326
+ idiag = torch.diag(i_logsm)
327
+ jdiag = torch.diag(j_logsm)
328
+ loss_i = idiag.sum() / len(idiag)
329
+ loss_j = jdiag.sum() / len(jdiag)
330
+
331
+ loss_saliency_inter = - loss_i - loss_j
332
+
333
+ # * intra-vid mode
334
+ if 'cls_idx' not in targets.keys(): # eval
335
+ return {"loss_s_inter": loss_saliency_inter}
336
+
337
+ cls_indices = targets['cls_idx'].bool()
338
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
339
+ sim_cls = sim_matrix(vid_feats, cls_feats)
340
+
341
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
342
+ idiag_cls = i_logsm_cls[cls_indices]
343
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
344
+
345
+ loss_saliency_intra = - loss_cls_i
346
+
347
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
348
+
349
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
350
+ loss_map = {
351
+ "spans": self.loss_spans,
352
+ "labels": self.loss_labels,
353
+ "saliency": self.loss_saliency,
354
+ "saliency_cls": self.loss_saliency_cls,
355
+ }
356
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
357
+ return loss_map[loss](outputs, targets, indices, **kwargs)
358
+
359
+ def forward(self, outputs, targets, mask_GT=None):
360
+ """ This performs the loss computation.
361
+ Parameters:
362
+ outputs: dict of tensors, see the output specification of the model for the format
363
+ targets: list of dicts, such that len(targets) == batch_size.
364
+ The expected keys in each dict depends on the losses applied, see each loss' doc
365
+ """
366
+ indices = None
367
+ # Compute all the requested losses
368
+ losses = {}
369
+ # pdb.set_trace()
370
+ outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
371
+ outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
372
+ targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])
373
+
374
+ for loss in self.losses:
375
+ losses.update(self.get_loss(loss, outputs, targets, indices))
376
+
377
+ return losses
378
+
379
+ class MLP(nn.Module):
380
+ """ Very simple multi-layer perceptron (also called FFN)"""
381
+
382
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
383
+ super().__init__()
384
+ self.num_layers = num_layers
385
+ h = [hidden_dim] * (num_layers - 1)
386
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
387
+
388
+ def forward(self, x):
389
+ for i, layer in enumerate(self.layers):
390
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
391
+ return x
392
+
393
+ class Conv(nn.Module):
394
+ """ Very simple multi-layer perceptron (also called FFN)"""
395
+
396
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
397
+ super().__init__()
398
+ self.num_layers = num_layers
399
+ h = [hidden_dim] * (num_layers - 1)
400
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
401
+ self.layers = nn.ModuleList(
402
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
403
+ for n, k in zip([input_dim] + h, h + [output_dim]))
404
+ def forward(self, x):
405
+ x = x.permute(0,2,1)
406
+ for i, layer in enumerate(self.layers):
407
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
408
+ return x.permute(0, 2, 1)
409
+
410
+ class LinearLayer(nn.Module):
411
+ """linear layer configurable with layer normalization, dropout, ReLU."""
412
+
413
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
414
+ super(LinearLayer, self).__init__()
415
+ self.relu = relu
416
+ self.layer_norm = layer_norm
417
+ if layer_norm:
418
+ self.LayerNorm = nn.LayerNorm(in_hsz)
419
+ layers = [
420
+ nn.Dropout(dropout),
421
+ nn.Linear(in_hsz, out_hsz)
422
+ ]
423
+ self.net = nn.Sequential(*layers)
424
+
425
+ def forward(self, x):
426
+ """(N, L, D)"""
427
+ if self.layer_norm:
428
+ x = self.LayerNorm(x)
429
+ x = self.net(x)
430
+ if self.relu:
431
+ x = F.relu(x, inplace=True)
432
+ return x # (N, L, D)
433
+
434
+
435
+ def build_model(args):
436
+ device = torch.device(args.device)
437
+
438
+ transformer = build_transformer(args)
439
+ position_embedding, txt_position_embedding = build_position_encoding(args)
440
+
441
+ model = Model(
442
+ transformer,
443
+ position_embedding,
444
+ txt_position_embedding,
445
+ txt_dim=args.t_feat_dim,
446
+ vid_dim=args.v_feat_dim,
447
+ input_dropout=args.input_dropout,
448
+ span_loss_type=args.span_loss_type,
449
+ use_txt_pos=args.use_txt_pos,
450
+ n_input_proj=args.n_input_proj,
451
+ )
452
+
453
+ matcher = build_matcher(args)
454
+ weight_dict = {"loss_b": args.b_loss_coef,
455
+ "loss_g": args.g_loss_coef,
456
+ "loss_f": args.f_loss_coef,
457
+ "loss_s_intra": args.s_loss_intra_coef,
458
+ "loss_s_inter": args.s_loss_inter_coef}
459
+
460
+ if args.dset_type in ['mr', 'vlp']:
461
+ if 'tal' not in args.train_path:
462
+ losses = ['spans', 'labels', 'saliency']
463
+ else:
464
+ losses = ['spans', 'labels', 'saliency_cls']
465
+ elif args.dset_type in ['hl', 'vs']:
466
+ losses = ['labels', 'saliency']
467
+
468
+ criterion = SetCriterion(
469
+ matcher=matcher,
470
+ weight_dict=weight_dict, losses=losses,
471
+ eos_coef=args.eos_coef, temperature=args.temperature,
472
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
473
+ saliency_margin=args.saliency_margin,
474
+ )
475
+ criterion.to(device)
476
+ return model, criterion
model/matcher.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Modules to compute the matching cost and solve the corresponding LSAP.
4
+ """
5
+ import torch
6
+ from scipy.optimize import linear_sum_assignment
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
10
+
11
+
12
+ class HungarianMatcher(nn.Module):
13
+ """This class computes an assignment between the targets and the predictions of the network
14
+
15
+ For efficiency reasons, the targets don't include the no_object. Because of this, in general,
16
+ there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
17
+ while the others are un-matched (and thus treated as non-objects).
18
+ """
19
+ def __init__(self, cost_class: float = 1, cost_span: float = 1, cost_giou: float = 1,
20
+ span_loss_type: str = "l1", max_v_l: int = 75):
21
+ """Creates the matcher
22
+
23
+ Params:
24
+ cost_span: This is the relative weight of the L1 error of the span coordinates in the matching cost
25
+ cost_giou: This is the relative weight of the giou loss of the spans in the matching cost
26
+ """
27
+ super().__init__()
28
+ self.cost_class = cost_class
29
+ self.cost_span = cost_span
30
+ self.cost_giou = cost_giou
31
+ self.span_loss_type = span_loss_type
32
+ self.max_v_l = max_v_l
33
+ self.foreground_label = 0
34
+ assert cost_class != 0 or cost_span != 0 or cost_giou != 0, "all costs cant be 0"
35
+
36
+ @torch.no_grad()
37
+ def forward(self, outputs, targets):
38
+ """ Performs the matching
39
+
40
+ Params:
41
+ outputs: This is a dict that contains at least these entries:
42
+ "pred_spans": Tensor of dim [batch_size, num_queries, 2] with the predicted span coordinates,
43
+ in normalized (cx, w) format
44
+ ""pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
45
+
46
+ targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
47
+ "spans": Tensor of dim [num_target_spans, 2] containing the target span coordinates. The spans are
48
+ in normalized (cx, w) format
49
+
50
+ Returns:
51
+ A list of size batch_size, containing tuples of (index_i, index_j) where:
52
+ - index_i is the indices of the selected predictions (in order)
53
+ - index_j is the indices of the corresponding selected targets (in order)
54
+ For each batch element, it holds:
55
+ len(index_i) = len(index_j) = min(num_queries, num_target_spans)
56
+ """
57
+ bs, num_queries = outputs["pred_spans"].shape[:2]
58
+ targets = targets["span_labels"]
59
+
60
+ # Also concat the target labels and spans
61
+ out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
62
+ tgt_spans = torch.cat([v["spans"] for v in targets]) # [num_target_spans in batch, 2]
63
+ tgt_ids = torch.full([len(tgt_spans)], self.foreground_label) # [total #spans in the batch]
64
+
65
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL,
66
+ # but approximate it in 1 - prob[target class].
67
+ # The 1 is a constant that doesn't change the matching, it can be omitted.
68
+ cost_class = -out_prob[:, tgt_ids] # [batch_size * num_queries, total #spans in the batch]
69
+
70
+ if self.span_loss_type == "l1":
71
+ # We flatten to compute the cost matrices in a batch
72
+ out_spans = outputs["pred_spans"].flatten(0, 1) # [batch_size * num_queries, 2]
73
+
74
+ # Compute the L1 cost between spans
75
+ cost_span = torch.cdist(out_spans, tgt_spans, p=1) # [batch_size * num_queries, total #spans in the batch]
76
+
77
+ # Compute the giou cost between spans
78
+ # [batch_size * num_queries, total #spans in the batch]
79
+ cost_giou = - generalized_temporal_iou(span_cxw_to_xx(out_spans), span_cxw_to_xx(tgt_spans))
80
+ else:
81
+ pred_spans = outputs["pred_spans"] # (bsz, #queries, max_v_l * 2)
82
+ pred_spans = pred_spans.view(bs * num_queries, 2, self.max_v_l).softmax(-1) # (bsz * #queries, 2, max_v_l)
83
+ cost_span = - pred_spans[:, 0][:, tgt_spans[:, 0]] - \
84
+ pred_spans[:, 1][:, tgt_spans[:, 1]] # (bsz * #queries, #spans)
85
+ # pred_spans = pred_spans.repeat(1, n_spans, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, max_v_l, 2)
86
+ # tgt_spans = tgt_spans.view(1, n_spans, 2).repeat(bs * num_queries, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, 2)
87
+ # cost_span = pred_spans[tgt_spans]
88
+ # cost_span = cost_span.view(bs * num_queries, n_spans)
89
+
90
+ # giou
91
+ cost_giou = 0
92
+
93
+ # Final cost matrix
94
+ # import ipdb; ipdb.set_trace()
95
+ C = self.cost_span * cost_span + self.cost_giou * cost_giou + self.cost_class * cost_class
96
+ C = C.view(bs, num_queries, -1).cpu()
97
+
98
+ sizes = [len(v["spans"]) for v in targets]
99
+ indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
100
+ return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
101
+
102
+
103
+ def build_matcher(args):
104
+ return HungarianMatcher(
105
+ cost_span=args.set_cost_span, cost_giou=args.set_cost_giou,
106
+ cost_class=args.set_cost_class, span_loss_type=args.span_loss_type, max_v_l=args.max_v_l
107
+ )
model/moment_detr.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ DETR model and criterion classes.
4
+ """
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
10
+
11
+ from model.transformer import build_transformer
12
+ from model.matcher import build_matcher
13
+ from model.position_encoding import build_position_encoding
14
+
15
+ @torch.no_grad()
16
+ def accuracy(output, target, topk=(1,)):
17
+ """Computes the precision@k for the specified values of k
18
+ output: (#items, #classes)
19
+ target: int,
20
+ """
21
+ maxk = max(topk)
22
+ num_items = output.size(0)
23
+
24
+ _, pred = output.topk(maxk, 1, True, True)
25
+ pred = pred.t()
26
+ correct = pred.eq(target)
27
+
28
+ res = []
29
+ for k in topk:
30
+ correct_k = correct[:k].view(-1).float().sum(0)
31
+ res.append(correct_k.mul_(100.0 / num_items))
32
+ return res
33
+
34
+ class Model(nn.Module):
35
+ """ This is the Moment-DETR module that performs moment localization. """
36
+
37
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
38
+ num_queries, input_dropout, aux_loss=False,
39
+ contrastive_align_loss=False, contrastive_hdim=64,
40
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
41
+ """ Initializes the model.
42
+ Parameters:
43
+ transformer: torch module of the transformer architecture. See transformer.py
44
+ position_embed: torch module of the position_embedding, See position_encoding.py
45
+ txt_position_embed: position_embedding for text
46
+ txt_dim: int, text query input dimension
47
+ vid_dim: int, video feature input dimension
48
+ num_queries: number of object queries, ie detection slot. This is the maximal number of objects
49
+ Moment-DETR can detect in a single video.
50
+ aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
51
+ contrastive_align_loss: If true, perform span - tokens contrastive learning
52
+ contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss
53
+ max_v_l: int, maximum #clips in videos
54
+ span_loss_type: str, one of [l1, ce]
55
+ l1: (center-x, width) regression.
56
+ ce: (st_idx, ed_idx) classification.
57
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
58
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
59
+ """
60
+ super().__init__()
61
+ self.num_queries = num_queries
62
+ self.transformer = transformer
63
+ self.position_embed = position_embed
64
+ self.txt_position_embed = txt_position_embed
65
+ hidden_dim = transformer.d_model
66
+ self.span_loss_type = span_loss_type
67
+ self.max_v_l = max_v_l
68
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
69
+ self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3)
70
+ self.class_embed = nn.Linear(hidden_dim, 2) # 0: background, 1: foreground
71
+ self.use_txt_pos = use_txt_pos
72
+ self.n_input_proj = n_input_proj
73
+ # self.foreground_thd = foreground_thd
74
+ # self.background_thd = background_thd
75
+ self.query_embed = nn.Embedding(num_queries, hidden_dim)
76
+ relu_args = [True] * 3
77
+ relu_args[n_input_proj-1] = False
78
+ self.input_txt_proj = nn.Sequential(*[
79
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
80
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
81
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
82
+ ][:n_input_proj])
83
+ self.input_vid_proj = nn.Sequential(*[
84
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
85
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
86
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
87
+ ][:n_input_proj])
88
+ self.contrastive_align_loss = contrastive_align_loss
89
+ if contrastive_align_loss:
90
+ self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim)
91
+ self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim)
92
+ self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim)
93
+
94
+ self.saliency_proj = nn.Linear(hidden_dim, 1)
95
+ self.aux_loss = aux_loss
96
+
97
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask):
98
+ """The forward expects two tensors:
99
+ - src_txt: [batch_size, L_txt, D_txt]
100
+ - src_txt_mask: [batch_size, L_txt], containing 0 on padded pixels,
101
+ will convert to 1 as padding later for transformer
102
+ - src_vid: [batch_size, L_vid, D_vid]
103
+ - src_vid_mask: [batch_size, L_vid], containing 0 on padded pixels,
104
+ will convert to 1 as padding later for transformer
105
+
106
+ It returns a dict with the following elements:
107
+ - "pred_spans": The normalized boxes coordinates for all queries, represented as
108
+ (center_x, width). These values are normalized in [0, 1],
109
+ relative to the size of each individual image (disregarding possible padding).
110
+ See PostProcess for information on how to retrieve the unnormalized bounding box.
111
+ - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
112
+ dictionnaries containing the two above keys for each decoder layer.
113
+ """
114
+ src_vid = self.input_vid_proj(src_vid)
115
+ src_txt = self.input_txt_proj(src_txt)
116
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
117
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
118
+ # TODO should we remove or use different positional embeddings to the src_txt?
119
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
120
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
121
+ # pos_txt = torch.zeros_like(src_txt)
122
+ # pad zeros for txt positions
123
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
124
+ # (#layers, bsz, #queries, d), (bsz, L_vid+L_txt, d)
125
+ hs, memory = self.transformer(src, ~mask, self.query_embed.weight, pos)
126
+ outputs_class = self.class_embed(hs) # (#layers, batch_size, #queries, #classes)
127
+ outputs_coord = self.span_embed(hs) # (#layers, bsz, #queries, 2 or max_v_l * 2)
128
+ if self.span_loss_type == "l1":
129
+ outputs_coord = outputs_coord.sigmoid()
130
+ out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]}
131
+
132
+ txt_mem = memory[:, src_vid.shape[1]:] # (bsz, L_txt, d)
133
+ vid_mem = memory[:, :src_vid.shape[1]] # (bsz, L_vid, d)
134
+ if self.contrastive_align_loss:
135
+ proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1)
136
+ proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1)
137
+ proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1)
138
+ out.update(dict(
139
+ proj_queries=proj_queries[-1],
140
+ proj_txt_mem=proj_txt_mem,
141
+ proj_vid_mem=proj_vid_mem
142
+ ))
143
+
144
+ out["saliency_scores"] = self.saliency_proj(vid_mem).squeeze(-1) # (bsz, L_vid)
145
+
146
+ if self.aux_loss:
147
+ # assert proj_queries and proj_txt_mem
148
+ out['aux_outputs'] = [
149
+ {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
150
+ if self.contrastive_align_loss:
151
+ assert proj_queries is not None
152
+ for idx, d in enumerate(proj_queries[:-1]):
153
+ out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem))
154
+ return out
155
+
156
+ # @torch.jit.unused
157
+ # def _set_aux_loss(self, outputs_class, outputs_coord):
158
+ # # this is a workaround to make torchscript happy, as torchscript
159
+ # # doesn't support dictionary with non-homogeneous values, such
160
+ # # as a dict having both a Tensor and a list.
161
+ # return [{'pred_logits': a, 'pred_spans': b}
162
+ # for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
163
+
164
+
165
+ class SetCriterion(nn.Module):
166
+ """ This class computes the loss for DETR.
167
+ The process happens in two steps:
168
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
169
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
170
+ """
171
+
172
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
173
+ saliency_margin=1):
174
+ """ Create the criterion.
175
+ Parameters:
176
+ matcher: module able to compute a matching between targets and proposals
177
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
178
+ eos_coef: relative classification weight applied to the no-object category
179
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
180
+ temperature: float, temperature for NCE loss
181
+ span_loss_type: str, [l1, ce]
182
+ max_v_l: int,
183
+ saliency_margin: float
184
+ """
185
+ super().__init__()
186
+ self.matcher = matcher
187
+ self.weight_dict = weight_dict
188
+ self.losses = losses
189
+ self.temperature = temperature
190
+ self.span_loss_type = span_loss_type
191
+ self.max_v_l = max_v_l
192
+ self.saliency_margin = saliency_margin
193
+
194
+ # foreground and background classification
195
+ self.foreground_label = 0
196
+ self.background_label = 1
197
+ self.eos_coef = eos_coef
198
+ empty_weight = torch.ones(2)
199
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
200
+ self.register_buffer('empty_weight', empty_weight)
201
+
202
+ def loss_spans(self, outputs, targets, indices):
203
+ """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
204
+ targets dicts must contain the key "spans" containing a tensor of dim [nb_tgt_spans, 2]
205
+ The target spans are expected in format (center_x, w), normalized by the image size.
206
+ """
207
+ assert 'pred_spans' in outputs
208
+ targets = targets["span_labels"]
209
+ idx = self._get_src_permutation_idx(indices)
210
+ src_spans = outputs['pred_spans'][idx] # (#spans, max_v_l * 2)
211
+ tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0) # (#spans, 2)
212
+ if self.span_loss_type == "l1":
213
+ loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none')
214
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
215
+ else: # ce
216
+ n_spans = src_spans.shape[0]
217
+ src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
218
+ loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none')
219
+
220
+ # giou
221
+ # src_span_indices = src_spans.max(1)[1] # (#spans, 2)
222
+ # src_span_indices[:, 1] += 1 # ed non-inclusive [st, ed)
223
+ #
224
+ # tgt_span_indices = tgt_spans
225
+ # tgt_span_indices[:, 1] += 1
226
+ # loss_giou = 1 - torch.diag(generalized_temporal_iou(src_span_indices, tgt_span_indices))
227
+ loss_giou = loss_span.new_zeros([1])
228
+
229
+ losses = {}
230
+ losses['loss_b'] = loss_span.mean()
231
+ losses['loss_g'] = loss_giou.mean()
232
+ return losses
233
+
234
+ def loss_labels(self, outputs, targets, indices, log=True):
235
+ """Classification loss (NLL)
236
+ targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
237
+ """
238
+ # TODO add foreground and background classifier. use all non-matched as background.
239
+ assert 'pred_logits' in outputs
240
+ src_logits = outputs['pred_logits'] # (batch_size, #queries, #classes=2)
241
+ # idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch
242
+ idx = self._get_src_permutation_idx(indices)
243
+ target_classes = torch.full(src_logits.shape[:2], self.background_label,
244
+ dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
245
+ target_classes[idx] = self.foreground_label
246
+
247
+ loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none")
248
+ losses = {'loss_f': loss_ce.mean()}
249
+
250
+ if log:
251
+ # TODO this should probably be a separate loss, not hacked in this one here
252
+ losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0]
253
+ return losses
254
+
255
+ def loss_saliency(self, outputs, targets, indices, log=True):
256
+ """higher scores for positive clips"""
257
+ if "saliency_pos_labels" not in targets:
258
+ return {"loss_s_intra": 0}
259
+ saliency_scores = outputs["saliency_scores"] # (N, L)
260
+ pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
261
+ neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
262
+ num_pairs = pos_indices.shape[1] # typically 2 or 4
263
+ batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
264
+ pos_scores = torch.stack(
265
+ [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
266
+ neg_scores = torch.stack(
267
+ [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
268
+ loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
269
+ / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
270
+ return {"loss_s_intra": loss_saliency}
271
+
272
+ def loss_contrastive_align(self, outputs, targets, indices, log=True):
273
+ """encourage higher scores between matched query span and input text"""
274
+ normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens
275
+ normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d)
276
+ logits = torch.einsum(
277
+ "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens)
278
+ logits = logits.sum(2) / self.temperature # (bsz, #queries)
279
+ idx = self._get_src_permutation_idx(indices)
280
+ positive_map = torch.zeros_like(logits, dtype=torch.bool)
281
+ positive_map[idx] = True
282
+ positive_logits = logits.masked_fill(~positive_map, 0)
283
+
284
+ pos_term = positive_logits.sum(1) # (bsz, )
285
+ num_pos = positive_map.sum(1) # (bsz, )
286
+ neg_term = logits.logsumexp(1) # (bsz, )
287
+ loss_nce = - pos_term / num_pos + neg_term # (bsz, )
288
+ losses = {"loss_contrastive_align": loss_nce.mean()}
289
+ return losses
290
+
291
+ def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True):
292
+ """encourage higher scores between matched query span and input text"""
293
+ # TODO (1) align vid_mem and txt_mem;
294
+ # TODO (2) change L1 loss as CE loss on 75 labels, similar to soft token prediction in MDETR
295
+ normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens
296
+ normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d)
297
+ logits = torch.einsum(
298
+ "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens)
299
+ logits = logits.sum(2) / self.temperature # (bsz, #queries)
300
+ idx = self._get_src_permutation_idx(indices)
301
+ positive_map = torch.zeros_like(logits, dtype=torch.bool)
302
+ positive_map[idx] = True
303
+ positive_logits = logits.masked_fill(~positive_map, 0)
304
+
305
+ pos_term = positive_logits.sum(1) # (bsz, )
306
+ num_pos = positive_map.sum(1) # (bsz, )
307
+ neg_term = logits.logsumexp(1) # (bsz, )
308
+ loss_nce = - pos_term / num_pos + neg_term # (bsz, )
309
+ losses = {"loss_contrastive_align": loss_nce.mean()}
310
+ return losses
311
+
312
+ def _get_src_permutation_idx(self, indices):
313
+ # permute predictions following indices
314
+ batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
315
+ src_idx = torch.cat([src for (src, _) in indices])
316
+ return batch_idx, src_idx # two 1D tensors of the same length
317
+
318
+ def _get_tgt_permutation_idx(self, indices):
319
+ # permute targets following indices
320
+ batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
321
+ tgt_idx = torch.cat([tgt for (_, tgt) in indices])
322
+ return batch_idx, tgt_idx
323
+
324
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
325
+ loss_map = {
326
+ "spans": self.loss_spans,
327
+ "labels": self.loss_labels,
328
+ "contrastive_align": self.loss_contrastive_align,
329
+ "saliency": self.loss_saliency,
330
+ }
331
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
332
+ return loss_map[loss](outputs, targets, indices, **kwargs)
333
+
334
+ def forward(self, outputs, targets):
335
+ """ This performs the loss computation.
336
+ Parameters:
337
+ outputs: dict of tensors, see the output specification of the model for the format
338
+ targets: list of dicts, such that len(targets) == batch_size.
339
+ The expected keys in each dict depends on the losses applied, see each loss' doc
340
+ """
341
+ outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
342
+
343
+ # Retrieve the matching between the outputs of the last layer and the targets
344
+ # list(tuples), each tuple is (pred_span_indices, tgt_span_indices)
345
+ indices = self.matcher(outputs_without_aux, targets)
346
+
347
+ # Compute all the requested losses
348
+ losses = {}
349
+ for loss in self.losses:
350
+ losses.update(self.get_loss(loss, outputs, targets, indices))
351
+
352
+ # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
353
+ if 'aux_outputs' in outputs:
354
+ for i, aux_outputs in enumerate(outputs['aux_outputs']):
355
+ indices = self.matcher(aux_outputs, targets)
356
+ for loss in self.losses:
357
+ if "saliency" == loss: # skip as it is only in the top layer
358
+ continue
359
+ kwargs = {}
360
+ l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs)
361
+ l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
362
+ losses.update(l_dict)
363
+
364
+ return losses
365
+
366
+
367
+ class MLP(nn.Module):
368
+ """ Very simple multi-layer perceptron (also called FFN)"""
369
+
370
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
371
+ super().__init__()
372
+ self.num_layers = num_layers
373
+ h = [hidden_dim] * (num_layers - 1)
374
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
375
+
376
+ def forward(self, x):
377
+ for i, layer in enumerate(self.layers):
378
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
379
+ return x
380
+
381
+
382
+ class LinearLayer(nn.Module):
383
+ """linear layer configurable with layer normalization, dropout, ReLU."""
384
+
385
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
386
+ super(LinearLayer, self).__init__()
387
+ self.relu = relu
388
+ self.layer_norm = layer_norm
389
+ if layer_norm:
390
+ self.LayerNorm = nn.LayerNorm(in_hsz)
391
+ layers = [
392
+ nn.Dropout(dropout),
393
+ nn.Linear(in_hsz, out_hsz)
394
+ ]
395
+ self.net = nn.Sequential(*layers)
396
+
397
+ def forward(self, x):
398
+ """(N, L, D)"""
399
+ if self.layer_norm:
400
+ x = self.LayerNorm(x)
401
+ x = self.net(x)
402
+ if self.relu:
403
+ x = F.relu(x, inplace=True)
404
+ return x # (N, L, D)
405
+
406
+
407
+ def build_model(args):
408
+ # the `num_classes` naming here is somewhat misleading.
409
+ # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
410
+ # is the maximum id for a class in your dataset. For example,
411
+ # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
412
+ # As another example, for a dataset that has a single class with id 1,
413
+ # you should pass `num_classes` to be 2 (max_obj_id + 1).
414
+ # For more details on this, check the following discussion
415
+ # https://github.com/facebookresearch/moment_bert/issues/108#issuecomment-650269223
416
+ device = torch.device(args.device)
417
+
418
+ transformer = build_transformer(args)
419
+ position_embedding, txt_position_embedding = build_position_encoding(args)
420
+
421
+ model = Model(
422
+ transformer,
423
+ position_embedding,
424
+ txt_position_embedding,
425
+ txt_dim=args.t_feat_dim,
426
+ vid_dim=args.v_feat_dim,
427
+ num_queries=args.num_queries,
428
+ input_dropout=args.input_dropout,
429
+ aux_loss=args.aux_loss,
430
+ # contrastive_align_loss=args.contrastive_align_loss,
431
+ # contrastive_hdim=args.contrastive_hdim,
432
+ span_loss_type=args.span_loss_type,
433
+ use_txt_pos=args.use_txt_pos,
434
+ n_input_proj=args.n_input_proj,
435
+ )
436
+
437
+ matcher = build_matcher(args)
438
+ weight_dict = {"loss_b": args.b_loss_coef,
439
+ "loss_g": args.g_loss_coef,
440
+ "loss_f": args.f_loss_coef,
441
+ "loss_s_intra": args.s_loss_intra_coef,
442
+ "loss_s_inter": args.s_loss_inter_coef}
443
+ # if args.contrastive_align_loss:
444
+ # weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef
445
+ # TODO this is a hack
446
+ if args.aux_loss:
447
+ aux_weight_dict = {}
448
+ for i in range(args.dec_layers - 1):
449
+ aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items() if k != "loss_saliency"})
450
+ weight_dict.update(aux_weight_dict)
451
+
452
+ losses = ['spans', 'labels', 'saliency']
453
+ # if args.contrastive_align_loss:
454
+ # losses += ["contrastive_align"]
455
+ criterion = SetCriterion(
456
+ matcher=matcher, weight_dict=weight_dict, losses=losses,
457
+ eos_coef=args.eos_coef, temperature=args.temperature,
458
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
459
+ saliency_margin=args.saliency_margin
460
+ )
461
+ criterion.to(device)
462
+ return model, criterion
model/position_encoding.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Various positional encodings for the transformer.
4
+ """
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ import numpy as np
9
+
10
+ def PositionalEncoding(n_position, d_hid):
11
+ def get_position_angle_vec(position, d_hid):
12
+ return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
13
+
14
+ sinusoid_table = np.array([get_position_angle_vec(pos_i, d_hid) for pos_i in range(n_position)])
15
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
16
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
17
+ return torch.FloatTensor(sinusoid_table) # shape:(1, maxLen(n_position), d_hid)
18
+
19
+ class TrainablePositionalEncoding(nn.Module):
20
+ """Construct the embeddings from word, position and token_type embeddings.
21
+ """
22
+ def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
23
+ super(TrainablePositionalEncoding, self).__init__()
24
+ self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
25
+ self.LayerNorm = nn.LayerNorm(hidden_size)
26
+ self.dropout = nn.Dropout(dropout)
27
+
28
+ def forward(self, input_feat):
29
+ """
30
+ Args:
31
+ input_feat: (N, L, D)
32
+ """
33
+ bsz, seq_length = input_feat.shape[:2]
34
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
35
+ position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L)
36
+
37
+ position_embeddings = self.position_embeddings(position_ids)
38
+
39
+ embeddings = self.LayerNorm(input_feat + position_embeddings)
40
+ embeddings = self.dropout(embeddings)
41
+ return embeddings
42
+
43
+
44
+ class PositionEmbeddingSine(nn.Module):
45
+ """
46
+ This is a more standard version of the position embedding, very similar to the one
47
+ used by the Attention is all you need paper, generalized to work on images. (To 1D sequences)
48
+ """
49
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
50
+ super().__init__()
51
+ self.num_pos_feats = num_pos_feats
52
+ self.temperature = temperature
53
+ self.normalize = normalize
54
+ if scale is not None and normalize is False:
55
+ raise ValueError("normalize should be True if scale is passed")
56
+ if scale is None:
57
+ scale = 2 * math.pi
58
+ self.scale = scale
59
+
60
+ def forward(self, x, mask):
61
+ """
62
+ Args:
63
+ x: torch.tensor, (batch_size, L, d)
64
+ mask: torch.tensor, (batch_size, L), with 1 as valid
65
+
66
+ Returns:
67
+
68
+ """
69
+ assert mask is not None
70
+ x_embed = mask.cumsum(1, dtype=torch.float32) # (bsz, L)
71
+ if self.normalize:
72
+ eps = 1e-6
73
+ x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
74
+
75
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
76
+ # import pdb; pdb.set_trace()
77
+ # dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
78
+ dim_t = self.temperature ** (2 * torch.div(dim_t, 2).int() / self.num_pos_feats)
79
+
80
+ pos_x = x_embed[:, :, None] / dim_t # (bsz, L, num_pos_feats)
81
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) # (bsz, L, num_pos_feats*2)
82
+ # import ipdb; ipdb.set_trace()
83
+ return pos_x # .permute(0, 2, 1) # (bsz, num_pos_feats*2, L)
84
+
85
+
86
+ class PositionEmbeddingLearned(nn.Module):
87
+ """
88
+ Absolute pos embedding, learned.
89
+ """
90
+ def __init__(self, num_pos_feats=256):
91
+ super().__init__()
92
+ self.row_embed = nn.Embedding(50, num_pos_feats)
93
+ self.col_embed = nn.Embedding(50, num_pos_feats)
94
+ self.reset_parameters()
95
+
96
+ def reset_parameters(self):
97
+ nn.init.uniform_(self.row_embed.weight)
98
+ nn.init.uniform_(self.col_embed.weight)
99
+
100
+ def forward(self, x, mask):
101
+ h, w = x.shape[-2:]
102
+ i = torch.arange(w, device=x.device)
103
+ j = torch.arange(h, device=x.device)
104
+ x_emb = self.col_embed(i)
105
+ y_emb = self.row_embed(j)
106
+ pos = torch.cat([
107
+ x_emb.unsqueeze(0).repeat(h, 1, 1),
108
+ y_emb.unsqueeze(1).repeat(1, w, 1),
109
+ ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
110
+ return pos
111
+
112
+
113
+ def build_position_encoding(args):
114
+ N_steps = args.hidden_dim
115
+ if args.position_embedding in ('v2', 'sine'):
116
+ # TODO find a better way of exposing other arguments
117
+ position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
118
+ # elif args.position_embedding in ('v3', 'learned'):
119
+ # position_embedding = PositionEmbeddingLearned(N_steps)
120
+ else:
121
+ raise ValueError(f"not supported {args.position_embedding}")
122
+
123
+ txt_pos_embed = TrainablePositionalEncoding(
124
+ max_position_embeddings=args.max_q_l,
125
+ hidden_size=args.hidden_dim, dropout=args.input_dropout)
126
+ return position_embedding, txt_pos_embed
model/transformer.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ DETR Transformer class.
4
+
5
+ Copy-paste from torch.nn.Transformer with modifications:
6
+ * positional encodings are passed in MHattention
7
+ * extra LN at the end of encoder is removed
8
+ * decoder returns a stack of activations from all decoding layers
9
+ """
10
+ import copy
11
+ from typing import Optional
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import nn, Tensor
16
+
17
+
18
+ class Transformer(nn.Module):
19
+
20
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
21
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
22
+ activation="relu", normalize_before=False,
23
+ return_intermediate_dec=False):
24
+ super().__init__()
25
+
26
+ # TransformerEncoderLayerThin
27
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
28
+ dropout, activation, normalize_before)
29
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
30
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
31
+
32
+ # TransformerDecoderLayerThin
33
+ decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
34
+ dropout, activation, normalize_before)
35
+ decoder_norm = nn.LayerNorm(d_model)
36
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
37
+ return_intermediate=return_intermediate_dec)
38
+
39
+ self._reset_parameters()
40
+
41
+ self.d_model = d_model
42
+ self.nhead = nhead
43
+
44
+ def _reset_parameters(self):
45
+ for p in self.parameters():
46
+ if p.dim() > 1:
47
+ nn.init.xavier_uniform_(p)
48
+
49
+ def forward(self, src, mask, query_embed, pos_embed):
50
+ """
51
+ Args:
52
+ src: (batch_size, L, d)
53
+ mask: (batch_size, L)
54
+ query_embed: (#queries, d)
55
+ pos_embed: (batch_size, L, d) the same as src
56
+
57
+ Returns:
58
+
59
+ """
60
+ # flatten NxCxHxW to HWxNxC
61
+ bs, l, d = src.shape
62
+ src = src.permute(1, 0, 2) # (L, batch_size, d)
63
+ pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
64
+ query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (#queries, batch_size, d)
65
+
66
+ tgt = torch.zeros_like(query_embed)
67
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) # (L, batch_size, d)
68
+ hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
69
+ pos=pos_embed, query_pos=query_embed) # (#layers, #queries, batch_size, d)
70
+ hs = hs.transpose(1, 2) # (#layers, batch_size, #qeries, d)
71
+ # memory = memory.permute(1, 2, 0) # (batch_size, d, L)
72
+ memory = memory.transpose(0, 1) # (batch_size, L, d)
73
+ return hs, memory
74
+
75
+
76
+ class TransformerEncoder(nn.Module):
77
+
78
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
79
+ super().__init__()
80
+ self.layers = _get_clones(encoder_layer, num_layers)
81
+ self.num_layers = num_layers
82
+ self.norm = norm
83
+ self.return_intermediate = return_intermediate
84
+
85
+ def forward(self, src,
86
+ mask: Optional[Tensor] = None,
87
+ src_key_padding_mask: Optional[Tensor] = None,
88
+ pos: Optional[Tensor] = None):
89
+ output = src
90
+
91
+ intermediate = []
92
+
93
+ for layer in self.layers:
94
+ output = layer(output, src_mask=mask,
95
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
96
+ if self.return_intermediate:
97
+ intermediate.append(output)
98
+
99
+ if self.norm is not None:
100
+ output = self.norm(output)
101
+
102
+ if self.return_intermediate:
103
+ return torch.stack(intermediate)
104
+
105
+ return output
106
+
107
+
108
+ class TransformerDecoder(nn.Module):
109
+
110
+ def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
111
+ super().__init__()
112
+ self.layers = _get_clones(decoder_layer, num_layers)
113
+ self.num_layers = num_layers
114
+ self.norm = norm
115
+ self.return_intermediate = return_intermediate
116
+
117
+ def forward(self, tgt, memory,
118
+ tgt_mask: Optional[Tensor] = None,
119
+ memory_mask: Optional[Tensor] = None,
120
+ tgt_key_padding_mask: Optional[Tensor] = None,
121
+ memory_key_padding_mask: Optional[Tensor] = None,
122
+ pos: Optional[Tensor] = None,
123
+ query_pos: Optional[Tensor] = None):
124
+ output = tgt
125
+
126
+ intermediate = []
127
+
128
+ for layer in self.layers:
129
+ output = layer(output, memory, tgt_mask=tgt_mask,
130
+ memory_mask=memory_mask,
131
+ tgt_key_padding_mask=tgt_key_padding_mask,
132
+ memory_key_padding_mask=memory_key_padding_mask,
133
+ pos=pos, query_pos=query_pos)
134
+ if self.return_intermediate:
135
+ intermediate.append(self.norm(output))
136
+
137
+ if self.norm is not None:
138
+ output = self.norm(output)
139
+ if self.return_intermediate:
140
+ intermediate.pop()
141
+ intermediate.append(output)
142
+
143
+ if self.return_intermediate:
144
+ return torch.stack(intermediate)
145
+
146
+ return output.unsqueeze(0)
147
+
148
+
149
+ class TransformerEncoderLayerThin(nn.Module):
150
+
151
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
152
+ activation="relu", normalize_before=False):
153
+ super().__init__()
154
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
155
+ # Implementation of Feedforward model
156
+ # self.linear1 = nn.Linear(d_model, dim_feedforward)
157
+ # self.dropout = nn.Dropout(dropout)
158
+ # self.linear2 = nn.Linear(dim_feedforward, d_model)
159
+ self.linear = nn.Linear(d_model, d_model)
160
+ self.norm = nn.LayerNorm(d_model)
161
+ self.dropout = nn.Dropout(dropout)
162
+
163
+ # self.activation = _get_activation_fn(activation)
164
+ self.normalize_before = normalize_before
165
+
166
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
167
+ return tensor if pos is None else tensor + pos
168
+
169
+ def forward_post(self,
170
+ src,
171
+ src_mask: Optional[Tensor] = None,
172
+ src_key_padding_mask: Optional[Tensor] = None,
173
+ pos: Optional[Tensor] = None):
174
+ q = k = self.with_pos_embed(src, pos)
175
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
176
+ key_padding_mask=src_key_padding_mask)[0]
177
+ src2 = self.linear(src2)
178
+ src = src + self.dropout(src2)
179
+ src = self.norm(src)
180
+ # src = src + self.dropout1(src2)
181
+ # src = self.norm1(src)
182
+ # src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
183
+ # src = src + self.dropout2(src2)
184
+ # src = self.norm2(src)
185
+ return src
186
+
187
+ def forward_pre(self, src,
188
+ src_mask: Optional[Tensor] = None,
189
+ src_key_padding_mask: Optional[Tensor] = None,
190
+ pos: Optional[Tensor] = None):
191
+ """not used"""
192
+ src2 = self.norm1(src)
193
+ q = k = self.with_pos_embed(src2, pos)
194
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
195
+ key_padding_mask=src_key_padding_mask)[0]
196
+ src = src + self.dropout1(src2)
197
+ src2 = self.norm2(src)
198
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
199
+ src = src + self.dropout2(src2)
200
+ return src
201
+
202
+ def forward(self, src,
203
+ src_mask: Optional[Tensor] = None,
204
+ src_key_padding_mask: Optional[Tensor] = None,
205
+ pos: Optional[Tensor] = None):
206
+ if self.normalize_before:
207
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
208
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
209
+
210
+
211
+ class TransformerEncoderLayer(nn.Module):
212
+
213
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
214
+ activation="relu", normalize_before=False):
215
+ super().__init__()
216
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
217
+ # Implementation of Feedforward model
218
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
219
+ self.dropout = nn.Dropout(dropout)
220
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
221
+
222
+ self.norm1 = nn.LayerNorm(d_model)
223
+ self.norm2 = nn.LayerNorm(d_model)
224
+ self.dropout1 = nn.Dropout(dropout)
225
+ self.dropout2 = nn.Dropout(dropout)
226
+
227
+ self.activation = _get_activation_fn(activation)
228
+ self.normalize_before = normalize_before
229
+
230
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
231
+ return tensor if pos is None else tensor + pos
232
+
233
+ def forward_post(self,
234
+ src,
235
+ src_mask: Optional[Tensor] = None,
236
+ src_key_padding_mask: Optional[Tensor] = None,
237
+ pos: Optional[Tensor] = None):
238
+ q = k = self.with_pos_embed(src, pos)
239
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
240
+ key_padding_mask=src_key_padding_mask)[0]
241
+ src = src + self.dropout1(src2)
242
+ src = self.norm1(src)
243
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
244
+ src = src + self.dropout2(src2)
245
+ src = self.norm2(src)
246
+ return src
247
+
248
+ def forward_pre(self, src,
249
+ src_mask: Optional[Tensor] = None,
250
+ src_key_padding_mask: Optional[Tensor] = None,
251
+ pos: Optional[Tensor] = None):
252
+ src2 = self.norm1(src)
253
+ q = k = self.with_pos_embed(src2, pos)
254
+ src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
255
+ key_padding_mask=src_key_padding_mask)[0]
256
+ src = src + self.dropout1(src2)
257
+ src2 = self.norm2(src)
258
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
259
+ src = src + self.dropout2(src2)
260
+ return src
261
+
262
+ def forward(self, src,
263
+ src_mask: Optional[Tensor] = None,
264
+ src_key_padding_mask: Optional[Tensor] = None,
265
+ pos: Optional[Tensor] = None):
266
+ if self.normalize_before:
267
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
268
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
269
+
270
+
271
+ class TransformerDecoderLayer(nn.Module):
272
+
273
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
274
+ activation="relu", normalize_before=False):
275
+ super().__init__()
276
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
277
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
278
+ # Implementation of Feedforward model
279
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
280
+ self.dropout = nn.Dropout(dropout)
281
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
282
+
283
+ self.norm1 = nn.LayerNorm(d_model)
284
+ self.norm2 = nn.LayerNorm(d_model)
285
+ self.norm3 = nn.LayerNorm(d_model)
286
+ self.dropout1 = nn.Dropout(dropout)
287
+ self.dropout2 = nn.Dropout(dropout)
288
+ self.dropout3 = nn.Dropout(dropout)
289
+
290
+ self.activation = _get_activation_fn(activation)
291
+ self.normalize_before = normalize_before
292
+
293
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
294
+ return tensor if pos is None else tensor + pos
295
+
296
+ def forward_post(self, tgt, memory,
297
+ tgt_mask: Optional[Tensor] = None,
298
+ memory_mask: Optional[Tensor] = None,
299
+ tgt_key_padding_mask: Optional[Tensor] = None,
300
+ memory_key_padding_mask: Optional[Tensor] = None,
301
+ pos: Optional[Tensor] = None,
302
+ query_pos: Optional[Tensor] = None):
303
+ q = k = self.with_pos_embed(tgt, query_pos)
304
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
305
+ key_padding_mask=tgt_key_padding_mask)[0]
306
+ tgt = tgt + self.dropout1(tgt2)
307
+ tgt = self.norm1(tgt)
308
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
309
+ key=self.with_pos_embed(memory, pos),
310
+ value=memory, attn_mask=memory_mask,
311
+ key_padding_mask=memory_key_padding_mask)[0]
312
+ tgt = tgt + self.dropout2(tgt2)
313
+ tgt = self.norm2(tgt)
314
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
315
+ tgt = tgt + self.dropout3(tgt2)
316
+ tgt = self.norm3(tgt)
317
+ return tgt
318
+
319
+ def forward_pre(self, tgt, memory,
320
+ tgt_mask: Optional[Tensor] = None,
321
+ memory_mask: Optional[Tensor] = None,
322
+ tgt_key_padding_mask: Optional[Tensor] = None,
323
+ memory_key_padding_mask: Optional[Tensor] = None,
324
+ pos: Optional[Tensor] = None,
325
+ query_pos: Optional[Tensor] = None):
326
+ tgt2 = self.norm1(tgt)
327
+ q = k = self.with_pos_embed(tgt2, query_pos)
328
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
329
+ key_padding_mask=tgt_key_padding_mask)[0]
330
+ tgt = tgt + self.dropout1(tgt2)
331
+ tgt2 = self.norm2(tgt)
332
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
333
+ key=self.with_pos_embed(memory, pos),
334
+ value=memory, attn_mask=memory_mask,
335
+ key_padding_mask=memory_key_padding_mask)[0]
336
+ tgt = tgt + self.dropout2(tgt2)
337
+ tgt2 = self.norm3(tgt)
338
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
339
+ tgt = tgt + self.dropout3(tgt2)
340
+ return tgt
341
+
342
+ def forward(self, tgt, memory,
343
+ tgt_mask: Optional[Tensor] = None,
344
+ memory_mask: Optional[Tensor] = None,
345
+ tgt_key_padding_mask: Optional[Tensor] = None,
346
+ memory_key_padding_mask: Optional[Tensor] = None,
347
+ pos: Optional[Tensor] = None,
348
+ query_pos: Optional[Tensor] = None):
349
+ if self.normalize_before:
350
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
351
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
352
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
353
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
354
+
355
+
356
+ class TransformerDecoderLayerThin(nn.Module):
357
+ """removed intermediate layer"""
358
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
359
+ activation="relu", normalize_before=False):
360
+ super().__init__()
361
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
362
+ self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
363
+ # Implementation of Feedforward model
364
+ self.linear1 = nn.Linear(d_model, d_model)
365
+ # self.linear1 = nn.Linear(d_model, dim_feedforward)
366
+ # self.dropout = nn.Dropout(dropout)
367
+ # self.linear2 = nn.Linear(dim_feedforward, d_model)
368
+
369
+ self.norm1 = nn.LayerNorm(d_model)
370
+ self.norm2 = nn.LayerNorm(d_model)
371
+ # self.norm3 = nn.LayerNorm(d_model)
372
+ self.dropout1 = nn.Dropout(dropout)
373
+ self.dropout2 = nn.Dropout(dropout)
374
+ # self.dropout3 = nn.Dropout(dropout)
375
+
376
+ # self.activation = _get_activation_fn(activation)
377
+ self.normalize_before = normalize_before
378
+
379
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
380
+ return tensor if pos is None else tensor + pos
381
+
382
+ def forward_post(self, tgt, memory,
383
+ tgt_mask: Optional[Tensor] = None,
384
+ memory_mask: Optional[Tensor] = None,
385
+ tgt_key_padding_mask: Optional[Tensor] = None,
386
+ memory_key_padding_mask: Optional[Tensor] = None,
387
+ pos: Optional[Tensor] = None,
388
+ query_pos: Optional[Tensor] = None):
389
+ q = k = self.with_pos_embed(tgt, query_pos)
390
+ tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
391
+ key_padding_mask=tgt_key_padding_mask)[0]
392
+ tgt = tgt + self.dropout1(tgt2)
393
+ tgt = self.norm1(tgt)
394
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
395
+ key=self.with_pos_embed(memory, pos),
396
+ value=memory, attn_mask=memory_mask,
397
+ key_padding_mask=memory_key_padding_mask)[0]
398
+ tgt2 = self.linear1(tgt2)
399
+ tgt = tgt + self.dropout2(tgt2)
400
+ tgt = self.norm2(tgt)
401
+ # tgt = tgt + self.dropout2(tgt2)
402
+ # tgt = self.norm2(tgt)
403
+ # tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
404
+ # tgt = tgt + self.dropout3(tgt2)
405
+ # tgt = self.norm3(tgt)
406
+ return tgt
407
+
408
+ def forward_pre(self, tgt, memory,
409
+ tgt_mask: Optional[Tensor] = None,
410
+ memory_mask: Optional[Tensor] = None,
411
+ tgt_key_padding_mask: Optional[Tensor] = None,
412
+ memory_key_padding_mask: Optional[Tensor] = None,
413
+ pos: Optional[Tensor] = None,
414
+ query_pos: Optional[Tensor] = None):
415
+ tgt2 = self.norm1(tgt)
416
+ q = k = self.with_pos_embed(tgt2, query_pos)
417
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
418
+ key_padding_mask=tgt_key_padding_mask)[0]
419
+ tgt = tgt + self.dropout1(tgt2)
420
+ tgt2 = self.norm2(tgt)
421
+ tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
422
+ key=self.with_pos_embed(memory, pos),
423
+ value=memory, attn_mask=memory_mask,
424
+ key_padding_mask=memory_key_padding_mask)[0]
425
+ tgt = tgt + self.dropout2(tgt2)
426
+ tgt2 = self.norm3(tgt)
427
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
428
+ tgt = tgt + self.dropout3(tgt2)
429
+ return tgt
430
+
431
+ def forward(self, tgt, memory,
432
+ tgt_mask: Optional[Tensor] = None,
433
+ memory_mask: Optional[Tensor] = None,
434
+ tgt_key_padding_mask: Optional[Tensor] = None,
435
+ memory_key_padding_mask: Optional[Tensor] = None,
436
+ pos: Optional[Tensor] = None,
437
+ query_pos: Optional[Tensor] = None):
438
+ if self.normalize_before:
439
+ return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
440
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
441
+ return self.forward_post(tgt, memory, tgt_mask, memory_mask,
442
+ tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
443
+
444
+
445
+
446
+ def _get_clones(module, N):
447
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
448
+
449
+
450
+ def build_transformer(args):
451
+ return Transformer(
452
+ d_model=args.hidden_dim,
453
+ dropout=args.dropout,
454
+ nhead=args.nheads,
455
+ dim_feedforward=args.dim_feedforward,
456
+ num_encoder_layers=args.enc_layers,
457
+ num_decoder_layers=args.dec_layers,
458
+ normalize_before=args.pre_norm,
459
+ return_intermediate_dec=True,
460
+ )
461
+
462
+
463
+ def _get_activation_fn(activation):
464
+ """Return an activation function given a string"""
465
+ if activation == "relu":
466
+ return F.relu
467
+ if activation == "gelu":
468
+ return F.gelu
469
+ if activation == "glu":
470
+ return F.glu
471
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
model/transformer_encoder.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import pdb
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn, Tensor
8
+
9
+ def mask_logits(inputs, mask, mask_value=-1e30):
10
+ mask = mask.type(torch.float32)
11
+ return inputs + (1.0 - mask) * mask_value
12
+
13
+
14
+ class Transformer(nn.Module):
15
+
16
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=4,
17
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
18
+ activation="relu", normalize_before=False, # False as default
19
+ return_intermediate_dec=False):
20
+ super().__init__()
21
+
22
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
23
+ dropout, activation, normalize_before)
24
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
25
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
26
+
27
+ self._reset_parameters()
28
+
29
+ self.d_model = d_model
30
+ self.nhead = nhead
31
+
32
+ def _reset_parameters(self):
33
+ for p in self.parameters():
34
+ if p.dim() > 1:
35
+ nn.init.xavier_uniform_(p)
36
+
37
+ def forward(self, src, mask, pos_embed):
38
+ """
39
+ Args:
40
+ src: (batch_size, L, d)
41
+ mask: (batch_size, L)
42
+ query_embed: (#queries, d) -> my imple (batch_size, d) and #queries=1
43
+ pos_embed: (batch_size, L, d) the same as src
44
+
45
+ Returns:
46
+
47
+ """
48
+ # flatten NxCxHxW to HWxNxC
49
+ src = src.permute(1, 0, 2) # (L, batch_size, d)
50
+ pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
51
+
52
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
53
+ memory = memory.transpose(0, 1)
54
+
55
+ return memory
56
+
57
+
58
+ class TransformerEncoder(nn.Module):
59
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
60
+ super().__init__()
61
+ self.layers = _get_clones(encoder_layer, num_layers)
62
+ self.num_layers = num_layers
63
+ self.norm = norm
64
+ self.return_intermediate = return_intermediate
65
+
66
+ def forward(self, src,
67
+ mask: Optional[Tensor] = None,
68
+ src_key_padding_mask: Optional[Tensor] = None,
69
+ pos: Optional[Tensor] = None):
70
+ output = src
71
+
72
+ intermediate = []
73
+
74
+ for layer in self.layers:
75
+ output = layer(output, src_mask=mask,
76
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
77
+ if self.return_intermediate:
78
+ intermediate.append(output)
79
+
80
+ if self.norm is not None:
81
+ output = self.norm(output)
82
+
83
+ if self.return_intermediate:
84
+ return torch.stack(intermediate)
85
+
86
+ return output
87
+
88
+
89
+ class TransformerEncoderLayer(nn.Module):
90
+
91
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
92
+ activation="relu", normalize_before=False):
93
+ super().__init__()
94
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
95
+ # Implementation of Feedforward model
96
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
97
+ self.dropout = nn.Dropout(dropout)
98
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
99
+
100
+ self.norm1 = nn.LayerNorm(d_model)
101
+ self.norm2 = nn.LayerNorm(d_model)
102
+ self.dropout1 = nn.Dropout(dropout)
103
+ self.dropout2 = nn.Dropout(dropout)
104
+
105
+ self.activation = _get_activation_fn(activation)
106
+ self.normalize_before = normalize_before
107
+
108
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
109
+ return tensor if pos is None else tensor + pos
110
+
111
+ def forward_post(self,
112
+ src,
113
+ src_mask: Optional[Tensor] = None,
114
+ src_key_padding_mask: Optional[Tensor] = None,
115
+ pos: Optional[Tensor] = None):
116
+ q = k = self.with_pos_embed(src, pos)
117
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
118
+ key_padding_mask=src_key_padding_mask)[0]
119
+ src = src + self.dropout1(src2)
120
+ src = self.norm1(src)
121
+ src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
122
+ src = src + self.dropout2(src2)
123
+ src = self.norm2(src)
124
+ return src
125
+
126
+ def forward(self, src,
127
+ src_mask: Optional[Tensor] = None,
128
+ src_key_padding_mask: Optional[Tensor] = None,
129
+ pos: Optional[Tensor] = None):
130
+ if self.normalize_before:
131
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
132
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
133
+
134
+
135
+ def _get_clones(module, N):
136
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
137
+
138
+
139
+ def build_transformer(args):
140
+ return Transformer(
141
+ d_model=args.hidden_dim,
142
+ dropout=args.dropout,
143
+ nhead=args.nheads,
144
+ dim_feedforward=args.dim_feedforward,
145
+ num_encoder_layers=args.enc_layers,
146
+ num_decoder_layers=args.dec_layers,
147
+ normalize_before=args.pre_norm,
148
+ return_intermediate_dec=True,
149
+ )
150
+
151
+ def _get_activation_fn(activation):
152
+ """Return an activation function given a string"""
153
+ if activation == "relu":
154
+ return F.relu
155
+ if activation == "gelu":
156
+ return F.gelu
157
+ if activation == "glu":
158
+ return F.glu
159
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
model/transformer_encoder_droppath.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import pdb
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn, Tensor
8
+
9
+ def mask_logits(inputs, mask, mask_value=-1e30):
10
+ mask = mask.type(torch.float32)
11
+ return inputs + (1.0 - mask) * mask_value
12
+
13
+
14
+ class Transformer(nn.Module):
15
+
16
+ def __init__(self, d_model=512, nhead=8, num_encoder_layers=4,
17
+ num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, droppath=0.1,
18
+ activation="gelu", normalize_before=False, # False as default
19
+ return_intermediate_dec=False):
20
+ super().__init__()
21
+
22
+ encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
23
+ dropout, droppath, activation, normalize_before)
24
+ encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
25
+ self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
26
+
27
+ self._reset_parameters()
28
+
29
+ self.d_model = d_model
30
+ self.nhead = nhead
31
+
32
+ def _reset_parameters(self):
33
+ for p in self.parameters():
34
+ if p.dim() > 1:
35
+ nn.init.xavier_uniform_(p)
36
+
37
+ def forward(self, src, mask, pos_embed):
38
+ """
39
+ Args:
40
+ src: (batch_size, L, d)
41
+ mask: (batch_size, L)
42
+ query_embed: (#queries, d) -> my imple (batch_size, d) and #queries=1
43
+ pos_embed: (batch_size, L, d) the same as src
44
+
45
+ Returns:
46
+
47
+ """
48
+ # flatten NxCxHxW to HWxNxC
49
+ src = src.permute(1, 0, 2) # (L, batch_size, d)
50
+ pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d)
51
+
52
+ memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
53
+ memory = memory.transpose(0, 1)
54
+
55
+ return memory
56
+
57
+
58
+ class TransformerEncoder(nn.Module):
59
+ def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False):
60
+ super().__init__()
61
+ self.layers = _get_clones(encoder_layer, num_layers)
62
+ self.num_layers = num_layers
63
+ self.norm = norm
64
+ self.return_intermediate = return_intermediate
65
+
66
+ def forward(self, src,
67
+ mask: Optional[Tensor] = None,
68
+ src_key_padding_mask: Optional[Tensor] = None,
69
+ pos: Optional[Tensor] = None):
70
+ output = src
71
+
72
+ intermediate = []
73
+
74
+ for layer in self.layers:
75
+ output = layer(output, src_mask=mask,
76
+ src_key_padding_mask=src_key_padding_mask, pos=pos)
77
+ if self.return_intermediate:
78
+ intermediate.append(output)
79
+
80
+ if self.norm is not None:
81
+ output = self.norm(output)
82
+
83
+ if self.return_intermediate:
84
+ return torch.stack(intermediate)
85
+
86
+ return output
87
+
88
+ class TransformerEncoderLayer(nn.Module):
89
+
90
+ def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, droppath=0.1,
91
+ activation="relu", normalize_before=False):
92
+ super().__init__()
93
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
94
+ # Implementation of Feedforward model
95
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
96
+ self.dropout = nn.Dropout(dropout)
97
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
98
+
99
+ self.norm1 = nn.LayerNorm(d_model)
100
+ self.norm2 = nn.LayerNorm(d_model)
101
+ # self.dropout1 = nn.Dropout(dropout)
102
+ # self.dropout2 = nn.Dropout(dropout)
103
+ self.droppath1 = DropPath(droppath)
104
+ self.droppath2 = DropPath(droppath)
105
+
106
+ self.activation = _get_activation_fn(activation)
107
+ self.normalize_before = normalize_before
108
+
109
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
110
+ return tensor if pos is None else tensor + pos
111
+
112
+ def forward_post(self,
113
+ src,
114
+ src_mask: Optional[Tensor] = None,
115
+ src_key_padding_mask: Optional[Tensor] = None,
116
+ pos: Optional[Tensor] = None):
117
+ q = k = self.with_pos_embed(src, pos)
118
+ src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
119
+ # src2 = self.self_attn_eff(q=q, k=k, v=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
120
+ src = src + self.droppath1(src2)
121
+ src = self.norm1(src)
122
+ src2 = self.linear2(self.activation(self.linear1(src)))
123
+ # src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
124
+ src = src + self.droppath2(src2)
125
+ src = self.norm2(src)
126
+ return src
127
+
128
+ def forward(self, src,
129
+ src_mask: Optional[Tensor] = None,
130
+ src_key_padding_mask: Optional[Tensor] = None,
131
+ pos: Optional[Tensor] = None):
132
+ if self.normalize_before:
133
+ return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
134
+ return self.forward_post(src, src_mask, src_key_padding_mask, pos)
135
+
136
+
137
+ def _get_clones(module, N):
138
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
139
+
140
+
141
+ def build_transformer(args):
142
+ return Transformer(
143
+ d_model=args.hidden_dim,
144
+ dropout=args.dropout,
145
+ droppath=args.droppath,
146
+ nhead=args.nheads,
147
+ dim_feedforward=args.dim_feedforward,
148
+ num_encoder_layers=args.enc_layers,
149
+ num_decoder_layers=args.dec_layers,
150
+ normalize_before=args.pre_norm,
151
+ return_intermediate_dec=True,
152
+ )
153
+
154
+ def drop_path(x, drop_prob=0.0, training=False):
155
+ """
156
+ Stochastic Depth per sample.
157
+ """
158
+ if drop_prob == 0.0 or not training:
159
+ return x
160
+
161
+ keep_prob = 1 - drop_prob
162
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
163
+ mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
164
+ mask.floor_()
165
+ x = x.div(keep_prob) * mask
166
+
167
+ return x
168
+
169
+
170
+ class DropPath(nn.Module):
171
+ """
172
+ Drop paths per sample (when applied in main path of residual blocks).
173
+ """
174
+
175
+ def __init__(self, drop_prob=None):
176
+ super(DropPath, self).__init__()
177
+
178
+ self.drop_prob = drop_prob
179
+
180
+ def forward(self, x):
181
+ x = x.permute(1, 0, 2)
182
+ res = drop_path(x, self.drop_prob, self.training)
183
+ return res.permute(1, 0, 2)
184
+ # return drop_path(x, self.drop_prob, self.training)
185
+
186
+ def _get_activation_fn(activation):
187
+ """Return an activation function given a string"""
188
+ if activation == "relu":
189
+ return F.relu
190
+ if activation == "gelu":
191
+ return F.gelu
192
+ if activation == "glu":
193
+ return F.glu
194
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
model/univtg.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder_droppath import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ # MLP Projector
103
+ self.weightedpool = WeightedPool(hidden_dim)
104
+
105
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
106
+ bs = src_vid.shape[0]
107
+ src_vid = self.input_vid_proj(src_vid)
108
+ src_txt = self.input_txt_proj(src_txt)
109
+ if src_cls is not None:
110
+ src_cls = self.input_txt_proj(src_cls)
111
+ device_id = src_vid.device
112
+
113
+ # type token.
114
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
115
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
116
+ if src_cls is not None:
117
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
118
+
119
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
120
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
121
+
122
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
123
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
124
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
125
+
126
+ memory = self.transformer(src, ~mask, pos)
127
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
128
+
129
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
130
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
131
+
132
+ if self.span_loss_type == "l1":
133
+ outputs_coord = outputs_coord.sigmoid()
134
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).to(device_id)
135
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
136
+ outputs_coord = outputs_coord * idx_mask
137
+ else:
138
+ raise NotImplementedError
139
+
140
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
141
+ 'src_vid_mask': src_vid_mask}
142
+
143
+ vid_mem_proj = src_vid
144
+
145
+ # word-level -> sentence-level
146
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
147
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
148
+
149
+ out["vid_mem_proj"] = vid_mem_proj
150
+ out["txt_mem_proj"] = txt_mem_proj
151
+ if src_cls is not None:
152
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
153
+ out["cls_mem_proj"] = cls_mem_proj
154
+ out["saliency_scores"] = sim
155
+ return out
156
+
157
+ class SetCriterion(nn.Module):
158
+ """ This class computes the loss for DETR.
159
+ The process happens in two steps:
160
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
161
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
162
+ """
163
+
164
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
165
+ saliency_margin=1):
166
+ """ Create the criterion.
167
+ Parameters:
168
+ matcher: module able to compute a matching between targets and proposals
169
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
170
+ eos_coef: relative classification weight applied to the no-object category
171
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
172
+ temperature: float, temperature for NCE loss
173
+ span_loss_type: str, [l1, ce]
174
+ max_v_l: int,
175
+ saliency_margin: float
176
+ """
177
+ super().__init__()
178
+ self.matcher = matcher
179
+ self.weight_dict = weight_dict
180
+ self.losses = losses
181
+ self.temperature = temperature
182
+ self.span_loss_type = span_loss_type
183
+ self.max_v_l = max_v_l
184
+ self.saliency_margin = saliency_margin
185
+ self.temperature = 0.07
186
+
187
+ # foreground and background classification
188
+ self.foreground_label = 0
189
+ self.background_label = 1
190
+ self.eos_coef = eos_coef
191
+ empty_weight = torch.ones(2)
192
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
193
+ self.register_buffer('empty_weight', empty_weight)
194
+
195
+ def loss_spans(self, outputs, targets, indices):
196
+ assert 'pred_spans' in outputs
197
+
198
+ start_spans = targets['timestamp']
199
+ pred_spans = outputs['pred_spans']
200
+ src_spans = start_spans + pred_spans
201
+ gt_spans = targets['span_labels_nn']
202
+
203
+ mask = targets['timestamp_mask'].bool()
204
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
205
+ mask_valid = targets['timestamp_window'].bool()
206
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
207
+
208
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
209
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
210
+
211
+ losses = {}
212
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
213
+ losses['loss_g'] = loss_giou.mean()
214
+ return losses
215
+
216
+ def loss_labels(self, outputs, targets, indices, log=True):
217
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
218
+ mask = targets['timestamp_mask'].bool()
219
+ mask_valid = targets['timestamp_window'].bool()
220
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
221
+ target_classes[mask_valid] = 1
222
+ # target_classes = targets['timestamp_window'] # soft cls.
223
+ target_classes.float()
224
+ # pdb.set_trace()
225
+
226
+ weights = torch.zeros_like(target_classes).float()
227
+ weights[mask] = self.empty_weight[1]
228
+ weights[mask_valid] = self.empty_weight[0]
229
+
230
+ # pdb.set_trace()
231
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
232
+ return {"loss_f": loss_ce.sum() / mask.sum()}
233
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
234
+
235
+ def loss_saliency(self, outputs, targets, indices, log=True):
236
+ """higher scores for positive clips"""
237
+ if "saliency_pos_labels" not in targets:
238
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
239
+ saliency_scores = targets["saliency_scores"]
240
+ if saliency_scores.sum() == 0:
241
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
242
+
243
+ # * inter-vid mode
244
+ vid_mem_proj = outputs["vid_mem_proj"]
245
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
246
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
247
+
248
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
249
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
250
+ sim = sim_matrix(vid_feats, txt_feats)
251
+
252
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
253
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
254
+
255
+ # sum over positives
256
+ idiag = torch.diag(i_logsm)
257
+ jdiag = torch.diag(j_logsm)
258
+ loss_i = idiag.sum() / len(idiag)
259
+ loss_j = jdiag.sum() / len(jdiag)
260
+
261
+ loss_saliency_inter = - loss_i - loss_j
262
+
263
+ # * intra-vid mode
264
+ mask = targets['timestamp_mask']
265
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
266
+ neg_indices_in = (saliency_scores < selected_scores)
267
+ neg_indices_in[batch_indices, pos_indices] = True
268
+ mask_invalid = neg_indices_in * mask.bool()
269
+
270
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
271
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
272
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
273
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
274
+
275
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
276
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
277
+ loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
278
+ loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
279
+
280
+ loss_saliency_intra = - loss_in_i - loss_in_j
281
+
282
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
283
+
284
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
285
+ """higher scores for positive clips"""
286
+ if "saliency_pos_labels" not in targets:
287
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
288
+ saliency_scores = targets["saliency_scores"]
289
+ if saliency_scores.sum() == 0:
290
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
291
+
292
+ # * inter-vid mode
293
+ vid_mem_proj = outputs["vid_mem_proj"]
294
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
295
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
296
+
297
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
298
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
299
+ sim = sim_matrix(vid_feats, txt_feats)
300
+
301
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
302
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
303
+
304
+ # sum over positives
305
+ idiag = torch.diag(i_logsm)
306
+ jdiag = torch.diag(j_logsm)
307
+ loss_i = idiag.sum() / len(idiag)
308
+ loss_j = jdiag.sum() / len(jdiag)
309
+
310
+ loss_saliency_inter = - loss_i - loss_j
311
+
312
+ # * intra-vid mode
313
+ if 'cls_idx' not in targets.keys(): # eval
314
+ return {"loss_s_inter": loss_saliency_inter}
315
+
316
+ cls_indices = targets['cls_idx'].bool()
317
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
318
+ sim_cls = sim_matrix(vid_feats, cls_feats)
319
+
320
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
321
+ idiag_cls = i_logsm_cls[cls_indices]
322
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
323
+
324
+ loss_saliency_intra = - loss_cls_i
325
+
326
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
327
+
328
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
329
+ loss_map = {
330
+ "spans": self.loss_spans,
331
+ "labels": self.loss_labels,
332
+ "saliency": self.loss_saliency,
333
+ "saliency_cls": self.loss_saliency_cls,
334
+ }
335
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
336
+ return loss_map[loss](outputs, targets, indices, **kwargs)
337
+
338
+ def forward(self, outputs, targets, hl_only=False):
339
+ """ This performs the loss computation.
340
+ Parameters:
341
+ outputs: dict of tensors, see the output specification of the model for the format
342
+ targets: list of dicts, such that len(targets) == batch_size.
343
+ The expected keys in each dict depends on the losses applied, see each loss' doc
344
+ """
345
+ indices = None
346
+ # Compute all the requested losses
347
+ losses = {}
348
+ for loss in self.losses:
349
+ losses.update(self.get_loss(loss, outputs, targets, indices))
350
+
351
+ return losses
352
+
353
+ class MLP(nn.Module):
354
+ """ Very simple multi-layer perceptron (also called FFN)"""
355
+
356
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
357
+ super().__init__()
358
+ self.num_layers = num_layers
359
+ h = [hidden_dim] * (num_layers - 1)
360
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
361
+
362
+ def forward(self, x):
363
+ for i, layer in enumerate(self.layers):
364
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
365
+ return x
366
+
367
+ class Conv(nn.Module):
368
+ """ Very simple multi-layer perceptron (also called FFN)"""
369
+
370
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
371
+ super().__init__()
372
+ self.num_layers = num_layers
373
+ h = [hidden_dim] * (num_layers - 1)
374
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
375
+ self.layers = nn.ModuleList(
376
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
377
+ for n, k in zip([input_dim] + h, h + [output_dim]))
378
+ def forward(self, x):
379
+ x = x.permute(0,2,1)
380
+ for i, layer in enumerate(self.layers):
381
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
382
+ return x.permute(0, 2, 1)
383
+
384
+ class LinearLayer(nn.Module):
385
+ """linear layer configurable with layer normalization, dropout, ReLU."""
386
+
387
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
388
+ super(LinearLayer, self).__init__()
389
+ self.relu = relu
390
+ self.layer_norm = layer_norm
391
+ if layer_norm:
392
+ self.LayerNorm = nn.LayerNorm(in_hsz)
393
+ layers = [
394
+ nn.Dropout(dropout),
395
+ nn.Linear(in_hsz, out_hsz)
396
+ ]
397
+ self.net = nn.Sequential(*layers)
398
+
399
+ def forward(self, x):
400
+ """(N, L, D)"""
401
+ if self.layer_norm:
402
+ x = self.LayerNorm(x)
403
+ x = self.net(x)
404
+ if self.relu:
405
+ x = F.relu(x, inplace=True)
406
+ return x # (N, L, D)
407
+
408
+
409
+ def build_model(args):
410
+ device = torch.device(args.device)
411
+
412
+ transformer = build_transformer(args)
413
+ position_embedding, txt_position_embedding = build_position_encoding(args)
414
+
415
+ model = Model(
416
+ transformer,
417
+ position_embedding,
418
+ txt_position_embedding,
419
+ txt_dim=args.t_feat_dim,
420
+ vid_dim=args.v_feat_dim,
421
+ input_dropout=args.input_dropout,
422
+ span_loss_type=args.span_loss_type,
423
+ use_txt_pos=args.use_txt_pos,
424
+ n_input_proj=args.n_input_proj,
425
+ )
426
+
427
+ matcher = build_matcher(args)
428
+ weight_dict = {"loss_b": args.b_loss_coef,
429
+ "loss_g": args.g_loss_coef,
430
+ "loss_f": args.f_loss_coef,
431
+ "loss_s_intra": args.s_loss_intra_coef,
432
+ "loss_s_inter": args.s_loss_inter_coef}
433
+
434
+ if args.dset_type in ['mr', 'vlp']:
435
+ if 'tal' not in args.train_path:
436
+ losses = ['spans', 'labels', 'saliency']
437
+ else:
438
+ losses = ['spans', 'labels', 'saliency_cls']
439
+ elif args.dset_type in ['hl', 'vs']:
440
+ losses = ['labels', 'saliency']
441
+
442
+ criterion = SetCriterion(
443
+ matcher=matcher,
444
+ weight_dict=weight_dict, losses=losses,
445
+ eos_coef=args.eos_coef, temperature=args.temperature,
446
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
447
+ saliency_margin=args.saliency_margin,
448
+ )
449
+ criterion.to(device)
450
+ return model, criterion
model/univtg_ablation.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder_droppath import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ # MLP Projector
103
+ self.weightedpool = WeightedPool(hidden_dim)
104
+
105
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
106
+ bs = src_vid.shape[0]
107
+ src_vid = self.input_vid_proj(src_vid)
108
+ src_txt = self.input_txt_proj(src_txt)
109
+ if src_cls is not None:
110
+ src_cls = self.input_txt_proj(src_cls)
111
+
112
+ # type token.
113
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
114
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
115
+ if src_cls is not None:
116
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
117
+
118
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
119
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
120
+
121
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
122
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
123
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
124
+
125
+ memory = self.transformer(src, ~mask, pos)
126
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
127
+
128
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
129
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
130
+
131
+ if self.span_loss_type == "l1":
132
+ outputs_coord = outputs_coord.sigmoid()
133
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
134
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
135
+ outputs_coord = outputs_coord * idx_mask
136
+ else:
137
+ raise NotImplementedError
138
+
139
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
140
+ 'src_vid_mask': src_vid_mask}
141
+
142
+ vid_mem_proj = src_vid
143
+
144
+ # word-level -> sentence-level
145
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
146
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
147
+
148
+ out["vid_mem_proj"] = vid_mem_proj
149
+ out["txt_mem_proj"] = txt_mem_proj
150
+ if src_cls is not None:
151
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
152
+ out["cls_mem_proj"] = cls_mem_proj
153
+ out["saliency_scores"] = sim
154
+ return out
155
+
156
+ class SetCriterion(nn.Module):
157
+ """ This class computes the loss for DETR.
158
+ The process happens in two steps:
159
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
160
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
161
+ """
162
+
163
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
164
+ saliency_margin=1):
165
+ """ Create the criterion.
166
+ Parameters:
167
+ matcher: module able to compute a matching between targets and proposals
168
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
169
+ eos_coef: relative classification weight applied to the no-object category
170
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
171
+ temperature: float, temperature for NCE loss
172
+ span_loss_type: str, [l1, ce]
173
+ max_v_l: int,
174
+ saliency_margin: float
175
+ """
176
+ super().__init__()
177
+ self.matcher = matcher
178
+ self.weight_dict = weight_dict
179
+ self.losses = losses
180
+ self.temperature = temperature
181
+ self.span_loss_type = span_loss_type
182
+ self.max_v_l = max_v_l
183
+ self.saliency_margin = saliency_margin
184
+ self.temperature = 0.07
185
+
186
+ # foreground and background classification
187
+ self.foreground_label = 0
188
+ self.background_label = 1
189
+ self.eos_coef = eos_coef
190
+ empty_weight = torch.ones(2)
191
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
192
+ self.register_buffer('empty_weight', empty_weight)
193
+
194
+ def loss_spans(self, outputs, targets, indices):
195
+ assert 'pred_spans' in outputs
196
+
197
+ start_spans = targets['timestamp']
198
+ pred_spans = outputs['pred_spans']
199
+ src_spans = start_spans + pred_spans
200
+ gt_spans = targets['span_labels_nn']
201
+
202
+ mask = targets['timestamp_mask'].bool()
203
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
204
+ mask_valid = targets['timestamp_window'].bool()
205
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
206
+
207
+ weight_abalation_b = targets['weight_ablation'][:,0].unsqueeze(-1)
208
+ if weight_abalation_b.sum() == 0:
209
+ return {"loss_f": torch.tensor(0).cuda(), "loss_g": torch.tensor(0).cuda()}
210
+
211
+ mask_valid = (mask_valid * weight_abalation_b).bool()
212
+ mask_valid_full = (mask_valid_full * weight_abalation_b.unsqueeze(-1)).bool()
213
+
214
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
215
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
216
+
217
+ losses = {}
218
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
219
+ losses['loss_g'] = loss_giou.mean()
220
+ return losses
221
+
222
+ def loss_labels(self, outputs, targets, indices, log=True):
223
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
224
+ mask = targets['timestamp_mask'].bool()
225
+ mask_valid = targets['timestamp_window'].bool()
226
+ target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
227
+ target_classes[mask_valid] = 1
228
+ # target_classes = targets['timestamp_window'] # soft cls.
229
+ target_classes.float()
230
+ # pdb.set_trace()
231
+
232
+ weights = torch.zeros_like(target_classes).float()
233
+ weights[mask] = self.empty_weight[1]
234
+ weights[mask_valid] = self.empty_weight[0]
235
+
236
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
237
+
238
+ weight_abalation_f = targets['weight_ablation'][:,2].unsqueeze(-1)
239
+ if weight_abalation_f.sum() == 0:
240
+ return {"loss_f": torch.tensor(0).cuda()}
241
+
242
+ mask = mask * weight_abalation_f
243
+ loss_ce = loss_ce * weight_abalation_f
244
+ return {"loss_f": loss_ce.sum() / mask.sum()}
245
+ # return {"loss_f": loss_ce.sum() / (1 + mask_valid.sum())}
246
+
247
+ def loss_saliency(self, outputs, targets, indices, log=True):
248
+ """higher scores for positive clips"""
249
+ if "saliency_pos_labels" not in targets:
250
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
251
+ saliency_scores = targets["saliency_scores"]
252
+ if saliency_scores.sum() == 0:
253
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
254
+
255
+ # * inter-vid mode
256
+ vid_mem_proj = outputs["vid_mem_proj"]
257
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
258
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
259
+
260
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
261
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
262
+ sim = sim_matrix(vid_feats, txt_feats)
263
+
264
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
265
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
266
+
267
+ # sum over positives
268
+ idiag = torch.diag(i_logsm)
269
+ jdiag = torch.diag(j_logsm)
270
+
271
+ weight_abalation_s = targets['weight_ablation'][:,3].bool()
272
+ if weight_abalation_s.sum() == 0:
273
+ return {"loss_s_inter": torch.tensor(0).cuda(),
274
+ "loss_s_intra": torch.tensor(0).cuda()}
275
+
276
+ _idiag = idiag[weight_abalation_s]
277
+ _jdiag = jdiag[weight_abalation_s]
278
+
279
+ loss_i = _idiag.sum() / len(_idiag)
280
+ loss_j = _jdiag.sum() / len(_jdiag)
281
+
282
+ loss_saliency_inter = - loss_i - loss_j
283
+
284
+ # * intra-vid mode
285
+ mask = targets['timestamp_mask']
286
+ selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
287
+ neg_indices_in = (saliency_scores < selected_scores)
288
+ neg_indices_in[batch_indices, pos_indices] = True
289
+ mask_invalid = neg_indices_in * mask.bool()
290
+
291
+ sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
292
+ sim_in = sim_in + (mask_invalid + 1e-45).log()
293
+ logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
294
+ logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
295
+
296
+ pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
297
+ pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
298
+ _pos_logsm_in_i = pos_logsm_in_i[weight_abalation_s]
299
+ _pos_logsm_in_j = pos_logsm_in_j[weight_abalation_s]
300
+
301
+ loss_in_i = _pos_logsm_in_i.sum() / len(_pos_logsm_in_i)
302
+ loss_in_j = _pos_logsm_in_j.sum() / len(_pos_logsm_in_j)
303
+
304
+ loss_saliency_intra = - loss_in_i - loss_in_j
305
+
306
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
307
+
308
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
309
+ """higher scores for positive clips"""
310
+ if "saliency_pos_labels" not in targets:
311
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
312
+ saliency_scores = targets["saliency_scores"]
313
+ if saliency_scores.sum() == 0:
314
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
315
+
316
+ # * inter-vid mode
317
+ vid_mem_proj = outputs["vid_mem_proj"]
318
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
319
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
320
+
321
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
322
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
323
+ sim = sim_matrix(vid_feats, txt_feats)
324
+
325
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
326
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
327
+
328
+ # sum over positives
329
+ idiag = torch.diag(i_logsm)
330
+ jdiag = torch.diag(j_logsm)
331
+ loss_i = idiag.sum() / len(idiag)
332
+ loss_j = jdiag.sum() / len(jdiag)
333
+
334
+ loss_saliency_inter = - loss_i - loss_j
335
+
336
+ # * intra-vid mode
337
+ if 'cls_idx' not in targets.keys(): # eval
338
+ return {"loss_s_inter": loss_saliency_inter}
339
+
340
+ cls_indices = targets['cls_idx'].bool()
341
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
342
+ sim_cls = sim_matrix(vid_feats, cls_feats)
343
+
344
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
345
+ idiag_cls = i_logsm_cls[cls_indices]
346
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
347
+
348
+ loss_saliency_intra = - loss_cls_i
349
+
350
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
351
+
352
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
353
+ loss_map = {
354
+ "spans": self.loss_spans,
355
+ "labels": self.loss_labels,
356
+ "saliency": self.loss_saliency,
357
+ "saliency_cls": self.loss_saliency_cls,
358
+ }
359
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
360
+ return loss_map[loss](outputs, targets, indices, **kwargs)
361
+
362
+ def forward(self, outputs, targets, hl_only=False):
363
+ """ This performs the loss computation.
364
+ Parameters:
365
+ outputs: dict of tensors, see the output specification of the model for the format
366
+ targets: list of dicts, such that len(targets) == batch_size.
367
+ The expected keys in each dict depends on the losses applied, see each loss' doc
368
+ """
369
+ indices = None
370
+ # Compute all the requested losses
371
+ losses = {}
372
+ for loss in self.losses:
373
+ losses.update(self.get_loss(loss, outputs, targets, indices))
374
+
375
+ return losses
376
+
377
+ class MLP(nn.Module):
378
+ """ Very simple multi-layer perceptron (also called FFN)"""
379
+
380
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
381
+ super().__init__()
382
+ self.num_layers = num_layers
383
+ h = [hidden_dim] * (num_layers - 1)
384
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
385
+
386
+ def forward(self, x):
387
+ for i, layer in enumerate(self.layers):
388
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
389
+ return x
390
+
391
+ class Conv(nn.Module):
392
+ """ Very simple multi-layer perceptron (also called FFN)"""
393
+
394
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
395
+ super().__init__()
396
+ self.num_layers = num_layers
397
+ h = [hidden_dim] * (num_layers - 1)
398
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
399
+ self.layers = nn.ModuleList(
400
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
401
+ for n, k in zip([input_dim] + h, h + [output_dim]))
402
+ def forward(self, x):
403
+ x = x.permute(0,2,1)
404
+ for i, layer in enumerate(self.layers):
405
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
406
+ return x.permute(0, 2, 1)
407
+
408
+ class LinearLayer(nn.Module):
409
+ """linear layer configurable with layer normalization, dropout, ReLU."""
410
+
411
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
412
+ super(LinearLayer, self).__init__()
413
+ self.relu = relu
414
+ self.layer_norm = layer_norm
415
+ if layer_norm:
416
+ self.LayerNorm = nn.LayerNorm(in_hsz)
417
+ layers = [
418
+ nn.Dropout(dropout),
419
+ nn.Linear(in_hsz, out_hsz)
420
+ ]
421
+ self.net = nn.Sequential(*layers)
422
+
423
+ def forward(self, x):
424
+ """(N, L, D)"""
425
+ if self.layer_norm:
426
+ x = self.LayerNorm(x)
427
+ x = self.net(x)
428
+ if self.relu:
429
+ x = F.relu(x, inplace=True)
430
+ return x # (N, L, D)
431
+
432
+
433
+ def build_model(args):
434
+ device = torch.device(args.device)
435
+
436
+ transformer = build_transformer(args)
437
+ position_embedding, txt_position_embedding = build_position_encoding(args)
438
+
439
+ model = Model(
440
+ transformer,
441
+ position_embedding,
442
+ txt_position_embedding,
443
+ txt_dim=args.t_feat_dim,
444
+ vid_dim=args.v_feat_dim,
445
+ input_dropout=args.input_dropout,
446
+ span_loss_type=args.span_loss_type,
447
+ use_txt_pos=args.use_txt_pos,
448
+ n_input_proj=args.n_input_proj,
449
+ )
450
+
451
+ matcher = build_matcher(args)
452
+ weight_dict = {"loss_b": args.b_loss_coef,
453
+ "loss_g": args.g_loss_coef,
454
+ "loss_f": args.f_loss_coef,
455
+ "loss_s_intra": args.s_loss_intra_coef,
456
+ "loss_s_inter": args.s_loss_inter_coef}
457
+
458
+ if args.dset_type in ['mr', 'vlp']:
459
+ if 'tal' not in args.train_path:
460
+ losses = ['spans', 'labels', 'saliency']
461
+ else:
462
+ losses = ['spans', 'labels', 'saliency_cls']
463
+ elif args.dset_type in ['hl', 'vs']:
464
+ losses = ['labels', 'saliency']
465
+
466
+ criterion = SetCriterion(
467
+ matcher=matcher,
468
+ weight_dict=weight_dict, losses=losses,
469
+ eos_coef=args.eos_coef, temperature=args.temperature,
470
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
471
+ saliency_margin=args.saliency_margin,
472
+ )
473
+ criterion.to(device)
474
+ return model, criterion
model/univtg_qfvs.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch import nn
5
+ import numpy as np
6
+
7
+ from model.transformer_encoder_droppath import build_transformer
8
+ from model.matcher import build_matcher
9
+ from model.position_encoding import build_position_encoding
10
+ from utils.span_utils import generalized_temporal_iou, span_cxw_to_xx
11
+
12
+ def init_weights(module):
13
+ if isinstance(module, (nn.Linear, nn.Embedding)):
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+ elif isinstance(module, nn.LayerNorm):
16
+ module.bias.data.zero_()
17
+ module.weight.data.fill_(1.0)
18
+
19
+ if isinstance(module, nn.Linear) and module.bias is not None:
20
+ module.bias.data.zero_()
21
+
22
+ def mask_logits(inputs, mask, mask_value=-1e30):
23
+ mask = mask.type(torch.float32)
24
+ return inputs + (1.0 - mask) * mask_value
25
+
26
+ def sim_matrix(a, b, eps=1e-8):
27
+ """
28
+ added eps for numerical stability
29
+ """
30
+ a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
31
+ a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
32
+ b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
33
+ sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
34
+ return sim_mt
35
+
36
+ class WeightedPool(nn.Module):
37
+ def __init__(self, dim):
38
+ super(WeightedPool, self).__init__()
39
+ weight = torch.empty(dim, 1)
40
+ nn.init.xavier_uniform_(weight)
41
+ self.weight = nn.Parameter(weight, requires_grad=True)
42
+
43
+ def forward(self, x, mask):
44
+ alpha = torch.tensordot(x, self.weight, dims=1) # shape = (batch_size, seq_length, 1)
45
+ alpha = mask_logits(alpha, mask=mask.unsqueeze(2))
46
+ alphas = nn.Softmax(dim=1)(alpha)
47
+ pooled_x = torch.matmul(x.transpose(1, 2), alphas) # (batch_size, dim, 1)
48
+ pooled_x = pooled_x.squeeze(2)
49
+ return pooled_x
50
+
51
+ class Model(nn.Module):
52
+ """ This is the UniVTG module that performs moment localization. """
53
+
54
+ def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
55
+ input_dropout, aux_loss=False,
56
+ max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2):
57
+ """ Initializes the model.
58
+ Parameters:
59
+ transformer: torch module of the transformer architecture. See transformer.py
60
+ position_embed: torch module of the position_embedding, See position_encoding.py
61
+ txt_position_embed: position_embedding for text
62
+ txt_dim: int, text query input dimension
63
+ vid_dim: int, video feature input dimension
64
+ max_v_l: int, maximum #clips in videos
65
+ span_loss_type: str, one of [l1, ce]
66
+ l1: (center-x, width) regression.
67
+ ce: (st_idx, ed_idx) classification.
68
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground
69
+ # background_thd: float, intersection over prediction <= background_thd: labeled background
70
+ """
71
+ super().__init__()
72
+ self.transformer = transformer
73
+ self.position_embed = position_embed
74
+ self.txt_position_embed = txt_position_embed
75
+ hidden_dim = transformer.d_model
76
+ self.span_loss_type = span_loss_type
77
+ self.max_v_l = max_v_l
78
+ span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
79
+
80
+ self.token_type_embeddings = nn.Embedding(2, hidden_dim)
81
+ self.token_type_embeddings.apply(init_weights)
82
+
83
+ # Conv projector
84
+ self.span_embed = Conv(hidden_dim, hidden_dim, span_pred_dim, 3, kernel_size=3)
85
+ self.class_embed = Conv(hidden_dim, hidden_dim, 1, 3, kernel_size=3) # 0: background, 1: foreground
86
+
87
+ self.use_txt_pos = use_txt_pos
88
+ self.n_input_proj = n_input_proj
89
+ relu_args = [True] * 3
90
+ relu_args[n_input_proj-1] = False
91
+ self.input_txt_proj = nn.Sequential(*[
92
+ LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
93
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
94
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
95
+ ][:n_input_proj])
96
+ self.input_vid_proj = nn.Sequential(*[
97
+ LinearLayer(vid_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
98
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
99
+ LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
100
+ ][:n_input_proj])
101
+
102
+ # MLP Projector
103
+ self.weightedpool = WeightedPool(hidden_dim)
104
+
105
+ def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, src_cls=None, src_cls_mask=None):
106
+ bs = src_vid.shape[0]
107
+ src_vid = self.input_vid_proj(src_vid)
108
+ src_txt = self.input_txt_proj(src_txt)
109
+ if src_cls is not None:
110
+ src_cls = self.input_txt_proj(src_cls)
111
+
112
+ # type token.
113
+ src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
114
+ src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
115
+ if src_cls is not None:
116
+ src_cls = src_cls + self.token_type_embeddings(torch.zeros_like(src_cls_mask.long()))
117
+
118
+ src = torch.cat([src_vid, src_txt], dim=1) # (bsz, L_vid+L_txt, d)
119
+ mask = torch.cat([src_vid_mask, src_txt_mask], dim=1).bool() # (bsz, L_vid+L_txt)
120
+
121
+ pos_vid = self.position_embed(src_vid, src_vid_mask) # (bsz, L_vid, d)
122
+ pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt) # (bsz, L_txt, d)
123
+ pos = torch.cat([pos_vid, pos_txt], dim=1)
124
+
125
+ memory = self.transformer(src, ~mask, pos)
126
+ vid_mem = memory[:, :src_vid.shape[1], :] # (bsz, L_vid, d)
127
+
128
+ outputs_class = self.class_embed(vid_mem).sigmoid() # (#layers, batch_size, #queries, #classes)
129
+ outputs_coord = self.span_embed(vid_mem) # (#layers, bsz, #queries, 2 or max_v_l * 2)
130
+
131
+ if self.span_loss_type == "l1":
132
+ outputs_coord = outputs_coord.sigmoid()
133
+ idx_mask = torch.tensor((-1, 1)).unsqueeze(0).unsqueeze(0).cuda()
134
+ idx_mask = idx_mask.repeat(outputs_coord.shape[0], outputs_coord.shape[1], 1)
135
+ outputs_coord = outputs_coord * idx_mask
136
+ else:
137
+ raise NotImplementedError
138
+
139
+ out = {'pred_logits': outputs_class, 'pred_spans': outputs_coord,
140
+ 'src_vid_mask': src_vid_mask}
141
+
142
+ vid_mem_proj = src_vid
143
+
144
+ # word-level -> sentence-level
145
+ txt_mem_proj = self.weightedpool(src_txt, src_txt_mask).unsqueeze(1)
146
+ sim = F.cosine_similarity(vid_mem_proj, txt_mem_proj, dim=-1) + (src_vid_mask + 1e-45).log()
147
+
148
+ out["vid_mem_proj"] = vid_mem_proj
149
+ out["txt_mem_proj"] = txt_mem_proj
150
+ if src_cls is not None:
151
+ cls_mem_proj = self.weightedpool(src_cls, src_cls_mask)
152
+ out["cls_mem_proj"] = cls_mem_proj
153
+ out["saliency_scores"] = sim
154
+ return out
155
+
156
+ class SetCriterion(nn.Module):
157
+ """ This class computes the loss for DETR.
158
+ The process happens in two steps:
159
+ 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
160
+ 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
161
+ """
162
+
163
+ def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
164
+ saliency_margin=1):
165
+ """ Create the criterion.
166
+ Parameters:
167
+ matcher: module able to compute a matching between targets and proposals
168
+ weight_dict: dict containing as key the names of the losses and as values their relative weight.
169
+ eos_coef: relative classification weight applied to the no-object category
170
+ losses: list of all the losses to be applied. See get_loss for list of available losses.
171
+ temperature: float, temperature for NCE loss
172
+ span_loss_type: str, [l1, ce]
173
+ max_v_l: int,
174
+ saliency_margin: float
175
+ """
176
+ super().__init__()
177
+ self.matcher = matcher
178
+ self.weight_dict = weight_dict
179
+ self.losses = losses
180
+ self.temperature = temperature
181
+ self.span_loss_type = span_loss_type
182
+ self.max_v_l = max_v_l
183
+ self.saliency_margin = saliency_margin
184
+ self.temperature = 0.07
185
+
186
+ # foreground and background classification
187
+ self.foreground_label = 0
188
+ self.background_label = 1
189
+ self.eos_coef = eos_coef
190
+ empty_weight = torch.ones(2)
191
+ empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0)
192
+ self.register_buffer('empty_weight', empty_weight)
193
+
194
+ def loss_spans(self, outputs, targets, indices):
195
+ assert 'pred_spans' in outputs
196
+
197
+ start_spans = targets['timestamp']
198
+ pred_spans = outputs['pred_spans']
199
+ src_spans = start_spans + pred_spans
200
+ gt_spans = targets['span_labels_nn']
201
+
202
+ mask = targets['timestamp_mask'].bool()
203
+ mask_full = targets['timestamp_mask'].unsqueeze(2).repeat(1, 1, 2)
204
+ mask_valid = targets['timestamp_window'].bool()
205
+ mask_valid_full = targets['timestamp_window'].unsqueeze(2).repeat(1, 1, 2)
206
+
207
+ loss_span = F.smooth_l1_loss(src_spans, gt_spans, reduction='none') * mask_valid_full
208
+ loss_giou = 1 - torch.diag(generalized_temporal_iou(src_spans[mask_valid], gt_spans[mask_valid]))
209
+
210
+ losses = {}
211
+ losses['loss_b'] = loss_span.sum() / mask_valid.sum()
212
+ losses['loss_g'] = loss_giou.mean()
213
+ return losses
214
+
215
+ def loss_labels(self, outputs, targets, indices, log=True):
216
+ saliency_scores = targets["saliency_scores"]
217
+ if saliency_scores.sum() == 0:
218
+ return {"loss_f": 0.}
219
+
220
+ src_logits = outputs['pred_logits'].squeeze(-1) # (batch_size, #queries, #classes=2)
221
+ target_classes = targets["saliency_scores"].squeeze()
222
+
223
+ weights = torch.ones_like(target_classes).float() * self.empty_weight[1]
224
+ weights[target_classes.bool()] = self.empty_weight[0]
225
+
226
+ loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), reduction="none")
227
+ return {"loss_f": loss_ce.sum() / target_classes.sum()}
228
+ # return {"loss_f": loss_ce.sum() / len(target_classes)}
229
+
230
+ # mask = targets['timestamp_mask'].bool()
231
+ # mask_valid = targets['timestamp_window'].bool()
232
+ # target_classes = torch.full(src_logits.shape[:2], 0, dtype=torch.int64, device=src_logits.device) # (batch_size, #queries)
233
+ # target_classes[mask_valid] = 1
234
+ # # target_classes = targets['timestamp_window'] # soft cls.
235
+ # target_classes.float()
236
+ # # pdb.set_trace()
237
+
238
+ # weights = torch.zeros_like(target_classes).float()
239
+ # weights[mask] = self.empty_weight[1]
240
+ # weights[mask_valid] = self.empty_weight[0]
241
+
242
+ # loss_ce = F.binary_cross_entropy(src_logits, target_classes.float(), weight=weights, reduction="none") * mask
243
+ # # return {"loss_f": loss_ce.sum() / mask.sum()}
244
+ # return {"loss_f": loss_ce.sum() / mask_valid.sum()}
245
+
246
+ def loss_saliency(self, outputs, targets, indices, log=True):
247
+ """higher scores for positive clips"""
248
+ if "saliency_pos_labels" not in targets:
249
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
250
+ saliency_scores = targets["saliency_scores"]
251
+ if saliency_scores.sum() == 0:
252
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
253
+
254
+ # * qfvs mil-nce mode
255
+ pos_indices = saliency_scores.squeeze() > 0
256
+
257
+ sim = outputs['saliency_scores']
258
+ sim_soft = F.softmax(sim / self.temperature, dim=0)
259
+ sim_log = torch.log(sim_soft[pos_indices])
260
+ loss_saliency_intra = -sim_log.sum() / len(sim_log)
261
+ return {"loss_s_inter": 0., "loss_s_intra": loss_saliency_intra}
262
+
263
+ # * inter-vid mode
264
+ # vid_mem_proj = outputs["vid_mem_proj"]
265
+ # pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
266
+ # batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
267
+
268
+ # vid_feats = vid_mem_proj[batch_indices, pos_indices]
269
+ # txt_feats = outputs["txt_mem_proj"].squeeze(1)
270
+ # sim = sim_matrix(vid_feats, txt_feats)
271
+
272
+ # i_logsm = F.log_softmax(sim / self.temperature, dim=1)
273
+ # j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
274
+
275
+ # # sum over positives
276
+ # idiag = torch.diag(i_logsm)
277
+ # jdiag = torch.diag(j_logsm)
278
+ # loss_i = idiag.sum() / len(idiag)
279
+ # loss_j = jdiag.sum() / len(jdiag)
280
+
281
+ # loss_saliency_inter = - loss_i - loss_j
282
+
283
+ # # * intra-vid mode
284
+ # mask = targets['timestamp_mask']
285
+ # selected_scores = saliency_scores[batch_indices, pos_indices].unsqueeze(-1)
286
+ # neg_indices_in = (saliency_scores < selected_scores)
287
+ # neg_indices_in[batch_indices, pos_indices] = True
288
+ # mask_invalid = neg_indices_in * mask.bool()
289
+
290
+ # sim_in = F.cosine_similarity(vid_mem_proj, txt_feats.unsqueeze(1), dim=-1)
291
+ # sim_in = sim_in + (mask_invalid + 1e-45).log()
292
+ # logsm_in_i = F.log_softmax(sim_in / self.temperature, dim=1)
293
+ # logsm_in_j = F.log_softmax(sim_in.t() / self.temperature, dim=1)
294
+
295
+ # pos_logsm_in_i = logsm_in_i[batch_indices, pos_indices]
296
+ # pos_logsm_in_j = logsm_in_j[pos_indices, batch_indices]
297
+ # loss_in_i = pos_logsm_in_i.sum() / len(pos_logsm_in_i)
298
+ # loss_in_j = pos_logsm_in_j.sum() / len(pos_logsm_in_j)
299
+
300
+ # loss_saliency_intra = - loss_in_i - loss_in_j
301
+
302
+ # return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
303
+
304
+ def loss_saliency_cls(self, outputs, targets, indices, log=True):
305
+ """higher scores for positive clips"""
306
+ if "saliency_pos_labels" not in targets:
307
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
308
+ saliency_scores = targets["saliency_scores"]
309
+ if saliency_scores.sum() == 0:
310
+ return {"loss_s_inter": 0., "loss_s_intra": 0.}
311
+
312
+ # * inter-vid mode
313
+ vid_mem_proj = outputs["vid_mem_proj"]
314
+ pos_indices = targets["saliency_pos_labels"][:,0].long() # (N, #pairs)
315
+ batch_indices = torch.arange(len(vid_mem_proj)).to(vid_mem_proj.device)
316
+
317
+ vid_feats = vid_mem_proj[batch_indices, pos_indices]
318
+ txt_feats = outputs["txt_mem_proj"].squeeze(1)
319
+ sim = sim_matrix(vid_feats, txt_feats)
320
+
321
+ i_logsm = F.log_softmax(sim / self.temperature, dim=1)
322
+ j_logsm = F.log_softmax(sim.t() /self.temperature, dim=1)
323
+
324
+ # sum over positives
325
+ idiag = torch.diag(i_logsm)
326
+ jdiag = torch.diag(j_logsm)
327
+ loss_i = idiag.sum() / len(idiag)
328
+ loss_j = jdiag.sum() / len(jdiag)
329
+
330
+ loss_saliency_inter = - loss_i - loss_j
331
+
332
+ # * intra-vid mode
333
+ if 'cls_idx' not in targets.keys(): # eval
334
+ return {"loss_s_inter": loss_saliency_inter}
335
+
336
+ cls_indices = targets['cls_idx'].bool()
337
+ cls_feats = outputs["cls_mem_proj"].squeeze(1)
338
+ sim_cls = sim_matrix(vid_feats, cls_feats)
339
+
340
+ i_logsm_cls = F.log_softmax(sim_cls / self.temperature, dim=1)
341
+ idiag_cls = i_logsm_cls[cls_indices]
342
+ loss_cls_i = idiag_cls.sum() / len(idiag_cls)
343
+
344
+ loss_saliency_intra = - loss_cls_i
345
+
346
+ return {"loss_s_inter": loss_saliency_inter, "loss_s_intra": loss_saliency_intra}
347
+
348
+ def get_loss(self, loss, outputs, targets, indices, **kwargs):
349
+ loss_map = {
350
+ "spans": self.loss_spans,
351
+ "labels": self.loss_labels,
352
+ "saliency": self.loss_saliency,
353
+ "saliency_cls": self.loss_saliency_cls,
354
+ }
355
+ assert loss in loss_map, f'do you really want to compute {loss} loss?'
356
+ return loss_map[loss](outputs, targets, indices, **kwargs)
357
+
358
+ def forward(self, outputs, targets, mask_GT=None):
359
+ """ This performs the loss computation.
360
+ Parameters:
361
+ outputs: dict of tensors, see the output specification of the model for the format
362
+ targets: list of dicts, such that len(targets) == batch_size.
363
+ The expected keys in each dict depends on the losses applied, see each loss' doc
364
+ """
365
+ indices = None
366
+ # Compute all the requested losses
367
+ losses = {}
368
+ outputs['pred_logits'] = outputs['pred_logits'].reshape(1, -1).masked_select(mask_GT[0])
369
+ count = mask_GT.sum()
370
+ outputs['saliency_scores'] = outputs['saliency_scores'].reshape(1, -1).masked_select(mask_GT[0])
371
+ # targets['saliency_scores'] = targets['saliency_scores'].masked_select(mask_GT[0])
372
+ targets['saliency_scores'] = targets['saliency_scores'][0,:count]
373
+
374
+ for loss in self.losses:
375
+ losses.update(self.get_loss(loss, outputs, targets, indices))
376
+
377
+ return losses
378
+
379
+ class MLP(nn.Module):
380
+ """ Very simple multi-layer perceptron (also called FFN)"""
381
+
382
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
383
+ super().__init__()
384
+ self.num_layers = num_layers
385
+ h = [hidden_dim] * (num_layers - 1)
386
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
387
+
388
+ def forward(self, x):
389
+ for i, layer in enumerate(self.layers):
390
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
391
+ return x
392
+
393
+ class Conv(nn.Module):
394
+ """ Very simple multi-layer perceptron (also called FFN)"""
395
+
396
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers, kernel_size):
397
+ super().__init__()
398
+ self.num_layers = num_layers
399
+ h = [hidden_dim] * (num_layers - 1)
400
+ # self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
401
+ self.layers = nn.ModuleList(
402
+ nn.Conv1d(n, k, kernel_size=kernel_size, stride=1, padding=kernel_size//2, dilation=1, groups=1, bias=True, padding_mode='zeros')
403
+ for n, k in zip([input_dim] + h, h + [output_dim]))
404
+ def forward(self, x):
405
+ x = x.permute(0,2,1)
406
+ for i, layer in enumerate(self.layers):
407
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
408
+ return x.permute(0, 2, 1)
409
+
410
+ class LinearLayer(nn.Module):
411
+ """linear layer configurable with layer normalization, dropout, ReLU."""
412
+
413
+ def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
414
+ super(LinearLayer, self).__init__()
415
+ self.relu = relu
416
+ self.layer_norm = layer_norm
417
+ if layer_norm:
418
+ self.LayerNorm = nn.LayerNorm(in_hsz)
419
+ layers = [
420
+ nn.Dropout(dropout),
421
+ nn.Linear(in_hsz, out_hsz)
422
+ ]
423
+ self.net = nn.Sequential(*layers)
424
+
425
+ def forward(self, x):
426
+ """(N, L, D)"""
427
+ if self.layer_norm:
428
+ x = self.LayerNorm(x)
429
+ x = self.net(x)
430
+ if self.relu:
431
+ x = F.relu(x, inplace=True)
432
+ return x # (N, L, D)
433
+
434
+
435
+ def build_model(args):
436
+ device = torch.device(args.device)
437
+
438
+ transformer = build_transformer(args)
439
+ position_embedding, txt_position_embedding = build_position_encoding(args)
440
+
441
+ model = Model(
442
+ transformer,
443
+ position_embedding,
444
+ txt_position_embedding,
445
+ txt_dim=args.t_feat_dim,
446
+ vid_dim=args.v_feat_dim,
447
+ input_dropout=args.input_dropout,
448
+ span_loss_type=args.span_loss_type,
449
+ use_txt_pos=args.use_txt_pos,
450
+ n_input_proj=args.n_input_proj,
451
+ )
452
+
453
+ matcher = build_matcher(args)
454
+ weight_dict = {"loss_b": args.b_loss_coef,
455
+ "loss_g": args.g_loss_coef,
456
+ "loss_f": args.f_loss_coef,
457
+ "loss_s_intra": args.s_loss_intra_coef,
458
+ "loss_s_inter": args.s_loss_inter_coef}
459
+
460
+ if args.dset_type in ['mr', 'vlp']:
461
+ if 'tal' not in args.train_path:
462
+ losses = ['spans', 'labels', 'saliency']
463
+ else:
464
+ losses = ['spans', 'labels', 'saliency_cls']
465
+ elif args.dset_type in ['hl', 'vs']:
466
+ losses = ['labels', 'saliency']
467
+
468
+ criterion = SetCriterion(
469
+ matcher=matcher,
470
+ weight_dict=weight_dict, losses=losses,
471
+ eos_coef=args.eos_coef, temperature=args.temperature,
472
+ span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
473
+ saliency_margin=args.saliency_margin,
474
+ )
475
+ criterion.to(device)
476
+ return model, criterion
requirements.txt ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.2.0
2
+ accelerate==0.19.0
3
+ aiodns==3.0.0
4
+ aiofiles==23.1.0
5
+ aiohttp==3.8.3
6
+ aiohttp-socks==0.7.1
7
+ aiosignal==1.3.1
8
+ altair==5.0.1
9
+ antiorm==1.2.1
10
+ antlr4-python3-runtime==4.9.3
11
+ anyio==3.7.0
12
+ appdirs==1.4.4
13
+ argilla==1.8.0
14
+ argon2-cffi==21.3.0
15
+ argon2-cffi-bindings==21.2.0
16
+ asttokens==2.0.7
17
+ async-timeout==4.0.2
18
+ attrs==22.1.0
19
+ Babel==2.12.1
20
+ backcall==0.2.0
21
+ backoff==2.2.1
22
+ beautifulsoup4==4.11.1
23
+ bert-score==0.3.13
24
+ black==22.3.0
25
+ bleach==5.0.1
26
+ blis==0.7.9
27
+ boto3==1.24.84
28
+ botocore==1.27.84
29
+ Brotli==1.0.9
30
+ brotlipy==0.7.0
31
+ cachetools==5.2.0
32
+ catalogue==2.0.8
33
+ cchardet==2.1.7
34
+ certifi==2023.5.7
35
+ cffi==1.15.1
36
+ chardet==5.1.0
37
+ charset-normalizer==2.1.1
38
+ cinemagoer==2023.5.1
39
+ click==8.1.3
40
+ cloudpickle==2.2.0
41
+ cmake==3.26.3
42
+ coloredlogs==15.0.1
43
+ colorlog==6.7.0
44
+ commonmark==0.9.1
45
+ confection==0.0.4
46
+ contourpy==1.0.6
47
+ cryptography==37.0.1
48
+ cycler==0.11.0
49
+ cymem==2.0.7
50
+ dataclasses==0.6
51
+ dataclasses-json==0.5.7
52
+ dataflow==0.9.5
53
+ db==0.1.1
54
+ db-sqlite3==0.0.1
55
+ debugpy==1.6.3
56
+ decoder==0.5
57
+ decorator==4.4.2
58
+ decord==0.6.0
59
+ defusedxml==0.7.1
60
+ Deprecated==1.2.14
61
+ detectron2==0.6
62
+ docker==6.0.0
63
+ docker-pycreds==0.4.0
64
+ easydict==1.9
65
+ ego4d==1.2.5
66
+ einops==0.6.0
67
+ elastic-transport==8.4.0
68
+ elasticsearch==8.5.0
69
+ entrypoints==0.4
70
+ et-xmlfile==1.1.0
71
+ exceptiongroup==1.1.1
72
+ executing==0.10.0
73
+ fairscale==0.4.12
74
+ fake-useragent==0.1.14
75
+ fastapi==0.98.0
76
+ fastjsonschema==2.16.1
77
+ ffmpeg==1.4
78
+ ffmpeg-python==0.2.0
79
+ ffmpy==0.3.0
80
+ ffprobe==0.5
81
+ filelock==3.7.1
82
+ fonttools==4.38.0
83
+ frozenlist==1.3.3
84
+ fsspec==2023.5.0
85
+ ftfy==6.1.1
86
+ future==0.18.2
87
+ fvcore==0.1.5.post20220512
88
+ gdown==4.7.1
89
+ gensim==4.2.0
90
+ geographiclib==2.0
91
+ geopy==2.3.0
92
+ gitdb==4.0.10
93
+ GitPython==3.1.31
94
+ glide-text2im==0.0.0
95
+ google-api-core==2.11.1
96
+ google-api-python-client==2.95.0
97
+ google-auth==2.22.0
98
+ google-auth-httplib2==0.1.0
99
+ google-auth-oauthlib==0.4.6
100
+ google-cloud==0.34.0
101
+ google-cloud-vision==3.4.4
102
+ google-measurement-protocol==1.1.0
103
+ googleapis-common-protos==1.59.1
104
+ googletransx==2.4.2
105
+ gradio==3.23.0
106
+ greenlet==2.0.2
107
+ grpcio==1.56.2
108
+ grpcio-status==1.56.2
109
+ h11==0.14.0
110
+ h5py==3.7.0
111
+ httpcore==0.16.3
112
+ httplib2==0.22.0
113
+ httpx==0.23.3
114
+ huggingface-hub==0.15.1
115
+ humanfriendly==10.0
116
+ hydra-core==1.2.0
117
+ idna==3.3
118
+ imageio==2.31.0
119
+ imageio-ffmpeg==0.4.7
120
+ importlib-metadata==4.12.0
121
+ importlib-resources==5.9.0
122
+ iopath==0.1.9
123
+ ipdb==0.13.11
124
+ ipykernel==6.15.3
125
+ ipython==8.4.0
126
+ ipython-genutils==0.2.0
127
+ ipywidgets==8.0.2
128
+ jedi==0.18.1
129
+ Jinja2==3.1.2
130
+ jmespath==1.0.1
131
+ joblib==1.1.0
132
+ jsonlines==3.1.0
133
+ jsonschema==4.16.0
134
+ jupyter==1.0.0
135
+ jupyter_client==7.3.5
136
+ jupyter-console==6.4.4
137
+ jupyter-core==4.11.1
138
+ jupyterlab-pygments==0.2.2
139
+ jupyterlab-widgets==3.0.3
140
+ kiwisolver==1.4.4
141
+ langchain==0.0.191
142
+ langcodes==3.3.0
143
+ language-evaluation==0.1.0
144
+ lazy_loader==0.2
145
+ linkify-it-py==2.0.2
146
+ lit==16.0.5.post0
147
+ lxml==4.9.1
148
+ Markdown==3.4.1
149
+ markdown-it-py==2.2.0
150
+ markdown2==2.4.9
151
+ MarkupSafe==2.1.1
152
+ marshmallow==3.19.0
153
+ marshmallow-enum==1.5.1
154
+ matplotlib==3.6.2
155
+ matplotlib-inline==0.1.3
156
+ mdit-py-plugins==0.3.3
157
+ mdurl==0.1.2
158
+ mistune==2.0.4
159
+ mkl-fft==1.3.1
160
+ mkl-random==1.2.2
161
+ mkl-service==2.4.0
162
+ monotonic==1.6
163
+ more-itertools==9.1.0
164
+ moviepy==1.0.3
165
+ mpmath==1.3.0
166
+ msg-parser==1.2.0
167
+ msgpack==1.0.4
168
+ msgpack-numpy==0.4.8
169
+ multidict==6.0.4
170
+ murmurhash==1.0.9
171
+ mutagen==1.46.0
172
+ mypy-extensions==0.4.3
173
+ nbclient==0.6.8
174
+ nbconvert==7.0.0
175
+ nbformat==5.5.0
176
+ nest-asyncio==1.5.5
177
+ networkx==2.8.7
178
+ nh3==0.2.13
179
+ nltk==3.7
180
+ nms-1d-cpu==0.0.0
181
+ nncore==0.3.6
182
+ notebook==6.4.12
183
+ numexpr==2.8.4
184
+ numpy==1.23.1
185
+ nvidia-cublas-cu11==11.10.3.66
186
+ nvidia-cuda-cupti-cu11==11.7.101
187
+ nvidia-cuda-nvrtc-cu11==11.7.99
188
+ nvidia-cuda-runtime-cu11==11.7.99
189
+ nvidia-cudnn-cu11==8.5.0.96
190
+ nvidia-cufft-cu11==10.9.0.58
191
+ nvidia-curand-cu11==10.2.10.91
192
+ nvidia-cusolver-cu11==11.4.0.1
193
+ nvidia-cusparse-cu11==11.7.4.91
194
+ nvidia-nccl-cu11==2.14.3
195
+ nvidia-nvtx-cu11==11.7.91
196
+ oauthlib==3.2.0
197
+ olefile==0.46
198
+ omegaconf==2.2.3
199
+ openai==0.27.7
200
+ openapi-schema-pydantic==1.2.4
201
+ opencv-python==4.5.4.58
202
+ openpyxl==3.1.2
203
+ orjson==3.9.1
204
+ ortools==9.4.1874
205
+ packaging==21.3
206
+ pandas==1.5.2
207
+ pandocfilters==1.5.0
208
+ parso==0.8.3
209
+ pathspec==0.10.1
210
+ pathtools==0.1.2
211
+ pathy==0.10.1
212
+ pdfminer.six==20221105
213
+ peft==0.3.0
214
+ pexpect==4.8.0
215
+ pickleshare==0.7.5
216
+ Pillow==9.3.0
217
+ pip==22.2.2
218
+ pkgutil_resolve_name==1.3.10
219
+ platformdirs==2.5.2
220
+ portalocker==2.5.1
221
+ preshed==3.0.8
222
+ prices==1.1.1
223
+ proglog==0.1.10
224
+ prometheus-client==0.14.1
225
+ prompt-toolkit==3.0.30
226
+ proto-plus==1.22.3
227
+ protobuf==3.20.1
228
+ psutil==5.9.2
229
+ ptyprocess==0.7.0
230
+ pure-eval==0.2.2
231
+ pyasn1==0.4.8
232
+ pyasn1-modules==0.2.8
233
+ pycares==4.2.2
234
+ pycipher==0.5.2
235
+ pycocoevalcap==1.2
236
+ pycocotools==2.0.5
237
+ pycparser==2.21
238
+ pycryptodomex==3.18.0
239
+ pydantic==1.10.8
240
+ pydot==1.4.2
241
+ pydub==0.25.1
242
+ pyfiglet==0.8.post1
243
+ Pygments==2.12.0
244
+ pynvml==11.5.0
245
+ pyOpenSSL==22.0.0
246
+ pypandoc==1.11
247
+ pyparsing==3.0.9
248
+ pyrsistent==0.18.1
249
+ PySocks==1.7.1
250
+ python-dateutil==2.8.2
251
+ python-docx==0.8.11
252
+ python-hostlist==1.21
253
+ python-magic==0.4.27
254
+ python-multipart==0.0.6
255
+ python-pptx==0.6.21
256
+ python-socks==2.0.3
257
+ pytz==2022.7
258
+ PyWavelets==1.4.1
259
+ PyYAML==6.0
260
+ pyzmq==23.2.1
261
+ qtconsole==5.3.2
262
+ QtPy==2.2.0
263
+ regex==2022.7.25
264
+ requests==2.28.1
265
+ requests-oauthlib==1.3.1
266
+ rfc3986==1.5.0
267
+ rich==13.0.1
268
+ rouge-score==0.1.2
269
+ rsa==4.9
270
+ ruamel.yaml==0.17.21
271
+ ruamel.yaml.clib==0.2.7
272
+ s3transfer==0.6.0
273
+ sacremoses==0.0.53
274
+ safetensors==0.3.1
275
+ schedule==1.1.0
276
+ scikit-image==0.21.0
277
+ scikit-learn==1.1.2
278
+ scipy==1.9.3
279
+ seaborn==0.12.0
280
+ semantic-version==2.10.0
281
+ Send2Trash==1.8.0
282
+ sentencepiece==0.1.99
283
+ sentry-sdk==1.26.0
284
+ setproctitle==1.3.2
285
+ setuptools==59.5.0
286
+ shortuuid==1.0.11
287
+ simplejson==3.17.6
288
+ six==1.16.0
289
+ smart-open==6.2.0
290
+ smmap==5.0.0
291
+ sniffio==1.3.0
292
+ soupsieve==2.3.2.post1
293
+ spacy==3.5.3
294
+ spacy-legacy==3.0.12
295
+ spacy-loggers==1.0.4
296
+ SQLAlchemy==2.0.15
297
+ srsly==2.4.6
298
+ stack-data==0.4.0
299
+ starlette==0.27.0
300
+ svgwrite==1.4.3
301
+ sympy==1.12
302
+ tabulate==0.8.10
303
+ tenacity==8.2.2
304
+ tensorboard==2.9.1
305
+ tensorboard-data-server==0.6.1
306
+ tensorboard-plugin-wit==1.8.1
307
+ termcolor==1.1.0
308
+ terminado==0.15.0
309
+ terminaltables==3.1.10
310
+ thinc==8.1.10
311
+ threadpoolctl==3.1.0
312
+ tifffile==2023.4.12
313
+ timm==0.4.12
314
+ tinycss2==1.1.1
315
+ tokenizers==0.13.2
316
+ tomli==2.0.1
317
+ toolz==0.12.0
318
+ torch==2.0.1
319
+ torchaudio==0.9.0a0+33b2469
320
+ torchdata==0.6.1
321
+ torchtext==0.15.2
322
+ torchvision==0.10.0a0
323
+ tornado==6.2
324
+ tqdm==4.64.1
325
+ traitlets==5.3.0
326
+ transformers==4.28.1
327
+ triton==2.0.0
328
+ twint==2.1.21
329
+ typer==0.7.0
330
+ typing_extensions==4.3.0
331
+ typing-inspect==0.9.0
332
+ uc-micro-py==1.0.2
333
+ unstructured==0.7.1
334
+ uritemplate==4.1.1
335
+ urllib3==1.26.12
336
+ uvicorn==0.22.0
337
+ wandb==0.15.4
338
+ warmup-scheduler==0.3
339
+ wasabi==1.1.2
340
+ wavedrom==2.0.3.post3
341
+ wcwidth==0.2.5
342
+ webencodings==0.5.1
343
+ websocket-client==1.4.1
344
+ websockets==11.0.3
345
+ Werkzeug==2.2.1
346
+ wheel==0.37.1
347
+ widgetsnbextension==4.0.3
348
+ wrapt==1.14.1
349
+ xlrd==2.0.1
350
+ XlsxWriter==3.1.2
351
+ yacs==0.1.8
352
+ yarl==1.9.2
353
+ youtube-dl==2021.12.17
354
+ yt-dlp==2023.3.4
355
+ zipp==3.8.1
results/omni/opt.json ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dset_type": "vlp",
3
+ "dset_name": "vlp",
4
+ "domain_name": null,
5
+ "model_id": "univtg",
6
+ "exp_id": "omni_mini_aio_unified__epo3_f10_b10g1_s0.1_0.1",
7
+ "device": 0,
8
+ "gpu_id": 0,
9
+ "debug": false,
10
+ "seed": 2018,
11
+ "local_rank": 0,
12
+ "eval_split_name": "val",
13
+ "data_ratio": 1.0,
14
+ "results_root": "results",
15
+ "num_workers": 8,
16
+ "no_pin_memory": false,
17
+ "bsz": 64,
18
+ "n_epoch": 100,
19
+ "max_es_cnt": 200,
20
+ "lr": 0.0001,
21
+ "lr_drop": 200,
22
+ "lr_gamma": 0.1,
23
+ "lr_warmup": 10.0,
24
+ "wd": 0.0001,
25
+ "grad_clip": 0.1,
26
+ "span_loss_type": "l1",
27
+ "b_loss_coef": 10.0,
28
+ "g_loss_coef": 1.0,
29
+ "eos_coef": 0.1,
30
+ "f_loss_coef": 10.0,
31
+ "s_loss_intra_coef": 0.1,
32
+ "s_loss_inter_coef": 0.1,
33
+ "main_metric": "[email protected]",
34
+ "eval_mode": null,
35
+ "eval_bsz": 32,
36
+ "eval_epoch": 5,
37
+ "eval_init": true,
38
+ "save_interval": 5,
39
+ "resume": "/data/home/qinghonglin/univtg/results/vlp-vlp/aio_unified_mini-clip-clip-2023_05_27_00/model_e0003.ckpt",
40
+ "resume_dir": null,
41
+ "resume_all": false,
42
+ "start_epoch": null,
43
+ "no_sort_results": false,
44
+ "max_before_nms": 1000,
45
+ "max_after_nms": 10,
46
+ "conf_thd": 0.0,
47
+ "nms_thd": 0.7,
48
+ "use_cache": -1,
49
+ "max_q_l": 75,
50
+ "max_v_l": 75,
51
+ "clip_length": 2.0,
52
+ "clip_len_list": null,
53
+ "max_windows": 5,
54
+ "add_easy_negative": 1,
55
+ "easy_negative_only": 1,
56
+ "round_multiple": 1,
57
+ "train_path": [
58
+ "data/qvhighlights/metadata/qvhighlights_train.jsonl",
59
+ "data/charades/metadata/charades_train.jsonl",
60
+ "data/ego4d/metadata/nlq_train.jsonl",
61
+ "data/tacos/metadata/train.jsonl",
62
+ "data/anet/metadata/train.jsonl",
63
+ "data/didemo/metadata/train.jsonl"
64
+ ],
65
+ "eval_path": "data/qvhighlights/metadata/qvhighlights_val.jsonl",
66
+ "train_path_list": null,
67
+ "eval_path_list": null,
68
+ "feat_root_list": null,
69
+ "no_norm_vfeat": false,
70
+ "no_norm_tfeat": false,
71
+ "v_feat_dirs": [
72
+ "vid_clip"
73
+ ],
74
+ "t_feat_dir": "txt_clip",
75
+ "v_feat_dim": 512,
76
+ "t_feat_dim": 512,
77
+ "ctx_mode": "video_tef",
78
+ "v_feat_types": "clip",
79
+ "t_feat_type": "clip",
80
+ "position_embedding": "sine",
81
+ "n_input_proj": 2,
82
+ "temperature": 0.07,
83
+ "enc_layers": 4,
84
+ "sub_enc_layers": 2,
85
+ "dec_layers": 2,
86
+ "dim_feedforward": 1024,
87
+ "hidden_dim": 512,
88
+ "input_dropout": 0.5,
89
+ "dropout": 0.0,
90
+ "droppath": 0.1,
91
+ "txt_drop_ratio": 0,
92
+ "use_txt_pos": false,
93
+ "nheads": 8,
94
+ "num_queries": 10,
95
+ "pre_norm": false,
96
+ "set_cost_span": 10,
97
+ "set_cost_giou": 1,
98
+ "set_cost_class": 4,
99
+ "saliency_margin": 0.2,
100
+ "aux_loss": false,
101
+ "max_segment_num": 20,
102
+ "max_frame_num": 200,
103
+ "top_percent": 0.02,
104
+ "qfvs_vid_feature": "fps1",
105
+ "qfvs_txt_feature": "query",
106
+ "qfvs_dense_shot": -1,
107
+ "qfvs_score_ensemble": -1,
108
+ "qfvs_score_gather": -1,
109
+ "qfvs_loss_gather": -1,
110
+ "results_dir": "results/vlp-vlp/omni_mini_aio_unified__epo3_f10_b10g1_s0.1_0.1-clip-clip-2023_05_31_06"
111
+ }
run_on_video/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from run_on_video.video_extractor import vid2clip, txt2clip
run_on_video/clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .clip import *
run_on_video/clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
run_on_video/clip/clip.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
+ from tqdm import tqdm
11
+
12
+ from .model import build_model
13
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
+
15
+ __all__ = ["available_models", "load", "tokenize"]
16
+ _tokenizer = _Tokenizer()
17
+
18
+ _MODELS = {
19
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
20
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
21
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
22
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
23
+ }
24
+
25
+
26
+ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
27
+ os.makedirs(root, exist_ok=True)
28
+ filename = os.path.basename(url)
29
+
30
+ expected_sha256 = url.split("/")[-2]
31
+ download_target = os.path.join(root, filename)
32
+
33
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
34
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
35
+
36
+ if os.path.isfile(download_target):
37
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
38
+ return download_target
39
+ else:
40
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
41
+
42
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
43
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
44
+ while True:
45
+ buffer = source.read(8192)
46
+ if not buffer:
47
+ break
48
+
49
+ output.write(buffer)
50
+ loop.update(len(buffer))
51
+
52
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
53
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
54
+
55
+ return download_target
56
+
57
+
58
+ def _transform(n_px):
59
+ return Compose([
60
+ Resize(n_px, interpolation=Image.BICUBIC),
61
+ CenterCrop(n_px),
62
+ lambda image: image.convert("RGB"),
63
+ ToTensor(),
64
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
65
+ ])
66
+
67
+
68
+ def available_models() -> List[str]:
69
+ """Returns the names of available CLIP models"""
70
+ return list(_MODELS.keys())
71
+
72
+
73
+ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
74
+ """Load a CLIP model
75
+
76
+ Parameters
77
+ ----------
78
+ name : str
79
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
80
+
81
+ device : Union[str, torch.device]
82
+ The device to put the loaded model
83
+
84
+ jit : bool
85
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
86
+
87
+ Returns
88
+ -------
89
+ model : torch.nn.Module
90
+ The CLIP model
91
+
92
+ preprocess : Callable[[PIL.Image], torch.Tensor]
93
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
94
+ """
95
+ if name in _MODELS:
96
+ model_path = _download(_MODELS[name])
97
+ elif os.path.isfile(name):
98
+ model_path = name
99
+ else:
100
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
101
+
102
+ try:
103
+ # loading JIT archive
104
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
105
+ state_dict = None
106
+ except RuntimeError:
107
+ # loading saved state dict
108
+ if jit:
109
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
110
+ jit = False
111
+ state_dict = torch.load(model_path, map_location="cpu")
112
+
113
+ if not jit:
114
+ model = build_model(state_dict or model.state_dict()).to(device)
115
+ if str(device) == "cpu":
116
+ model.float()
117
+ return model, _transform(model.visual.input_resolution)
118
+
119
+ # patch the device names
120
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
121
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
122
+
123
+ def patch_device(module):
124
+ graphs = [module.graph] if hasattr(module, "graph") else []
125
+ if hasattr(module, "forward1"):
126
+ graphs.append(module.forward1.graph)
127
+
128
+ for graph in graphs:
129
+ for node in graph.findAllNodes("prim::Constant"):
130
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
131
+ node.copyAttributes(device_node)
132
+
133
+ model.apply(patch_device)
134
+ patch_device(model.encode_image)
135
+ patch_device(model.encode_text)
136
+
137
+ # patch dtype to float32 on CPU
138
+ if str(device) == "cpu":
139
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
140
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
141
+ float_node = float_input.node()
142
+
143
+ def patch_float(module):
144
+ graphs = [module.graph] if hasattr(module, "graph") else []
145
+ if hasattr(module, "forward1"):
146
+ graphs.append(module.forward1.graph)
147
+
148
+ for graph in graphs:
149
+ for node in graph.findAllNodes("aten::to"):
150
+ inputs = list(node.inputs())
151
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
152
+ if inputs[i].node()["value"] == 5:
153
+ inputs[i].node().copyAttributes(float_node)
154
+
155
+ model.apply(patch_float)
156
+ patch_float(model.encode_image)
157
+ patch_float(model.encode_text)
158
+
159
+ model.float()
160
+
161
+ return model, _transform(model.input_resolution.item())
162
+
163
+
164
+ def tokenize(texts: Union[str, List[str]], context_length: int = 77, max_valid_length: int = 32) -> torch.LongTensor:
165
+ """
166
+ Returns the tokenized representation of given input string(s)
167
+
168
+ Parameters
169
+ ----------
170
+ texts : Union[str, List[str]]
171
+ An input string or a list of input strings to tokenize
172
+
173
+ context_length : int
174
+ The context length to use; all CLIP models use 77 as the context length
175
+
176
+ max_valid_length:
177
+
178
+ Returns
179
+ -------
180
+ A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
181
+ """
182
+ if isinstance(texts, str):
183
+ texts = [texts]
184
+
185
+ sot_token = _tokenizer.encoder["<|startoftext|>"]
186
+ eot_token = _tokenizer.encoder["<|endoftext|>"]
187
+ all_tokens = [[sot_token] + _tokenizer.encode(text)[:max_valid_length-2] + [eot_token] for text in texts]
188
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
189
+
190
+ for i, tokens in enumerate(all_tokens):
191
+ if len(tokens) > context_length:
192
+ raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
193
+ result[i, :len(tokens)] = torch.tensor(tokens)
194
+
195
+ return result
run_on_video/clip/model.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+
20
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+
23
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24
+
25
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27
+
28
+ self.relu = nn.ReLU(inplace=True)
29
+ self.downsample = None
30
+ self.stride = stride
31
+
32
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
33
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34
+ self.downsample = nn.Sequential(OrderedDict([
35
+ ("-1", nn.AvgPool2d(stride)),
36
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37
+ ("1", nn.BatchNorm2d(planes * self.expansion))
38
+ ]))
39
+
40
+ def forward(self, x: torch.Tensor):
41
+ identity = x
42
+
43
+ out = self.relu(self.bn1(self.conv1(x)))
44
+ out = self.relu(self.bn2(self.conv2(out)))
45
+ out = self.avgpool(out)
46
+ out = self.bn3(self.conv3(out))
47
+
48
+ if self.downsample is not None:
49
+ identity = self.downsample(x)
50
+
51
+ out += identity
52
+ out = self.relu(out)
53
+ return out
54
+
55
+
56
+ class AttentionPool2d(nn.Module):
57
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58
+ super().__init__()
59
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
61
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
62
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64
+ self.num_heads = num_heads
65
+
66
+ def forward(self, x):
67
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70
+ x, _ = F.multi_head_attention_forward(
71
+ query=x, key=x, value=x,
72
+ embed_dim_to_check=x.shape[-1],
73
+ num_heads=self.num_heads,
74
+ q_proj_weight=self.q_proj.weight,
75
+ k_proj_weight=self.k_proj.weight,
76
+ v_proj_weight=self.v_proj.weight,
77
+ in_proj_weight=None,
78
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79
+ bias_k=None,
80
+ bias_v=None,
81
+ add_zero_attn=False,
82
+ dropout_p=0,
83
+ out_proj_weight=self.c_proj.weight,
84
+ out_proj_bias=self.c_proj.bias,
85
+ use_separate_proj_weight=True,
86
+ training=self.training,
87
+ need_weights=False
88
+ )
89
+
90
+ return x[0]
91
+
92
+
93
+ class ModifiedResNet(nn.Module):
94
+ """
95
+ A ResNet class that is similar to torchvision's but contains the following changes:
96
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98
+ - The final pooling layer is a QKV attention instead of an average pool
99
+ """
100
+
101
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102
+ super().__init__()
103
+ self.output_dim = output_dim
104
+ self.input_resolution = input_resolution
105
+
106
+ # the 3-layer stem
107
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108
+ self.bn1 = nn.BatchNorm2d(width // 2)
109
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110
+ self.bn2 = nn.BatchNorm2d(width // 2)
111
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112
+ self.bn3 = nn.BatchNorm2d(width)
113
+ self.avgpool = nn.AvgPool2d(2)
114
+ self.relu = nn.ReLU(inplace=True)
115
+
116
+ # residual layers
117
+ self._inplanes = width # this is a *mutable* variable used during construction
118
+ self.layer1 = self._make_layer(width, layers[0])
119
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122
+
123
+ embed_dim = width * 32 # the ResNet feature dimension
124
+ self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125
+
126
+ def _make_layer(self, planes, blocks, stride=1):
127
+ layers = [Bottleneck(self._inplanes, planes, stride)]
128
+
129
+ self._inplanes = planes * Bottleneck.expansion
130
+ for _ in range(1, blocks):
131
+ layers.append(Bottleneck(self._inplanes, planes))
132
+
133
+ return nn.Sequential(*layers)
134
+
135
+ def forward(self, x):
136
+ def stem(x):
137
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138
+ x = self.relu(bn(conv(x)))
139
+ x = self.avgpool(x)
140
+ return x
141
+
142
+ x = x.type(self.conv1.weight.dtype)
143
+ x = stem(x)
144
+ x = self.layer1(x)
145
+ x = self.layer2(x)
146
+ x = self.layer3(x)
147
+ x = self.layer4(x)
148
+ x = self.attnpool(x)
149
+
150
+ return x
151
+
152
+
153
+ class LayerNorm(nn.LayerNorm):
154
+ """Subclass torch's LayerNorm to handle fp16."""
155
+
156
+ def forward(self, x: torch.Tensor):
157
+ orig_type = x.dtype
158
+ ret = super().forward(x.type(torch.float32))
159
+ return ret.type(orig_type)
160
+
161
+
162
+ class QuickGELU(nn.Module):
163
+ def forward(self, x: torch.Tensor):
164
+ return x * torch.sigmoid(1.702 * x)
165
+
166
+
167
+ class ResidualAttentionBlock(nn.Module):
168
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169
+ super().__init__()
170
+
171
+ self.attn = nn.MultiheadAttention(d_model, n_head)
172
+ self.ln_1 = LayerNorm(d_model)
173
+ self.mlp = nn.Sequential(OrderedDict([
174
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
175
+ ("gelu", QuickGELU()),
176
+ ("c_proj", nn.Linear(d_model * 4, d_model))
177
+ ]))
178
+ self.ln_2 = LayerNorm(d_model)
179
+ self.attn_mask = attn_mask
180
+
181
+ def attention(self, x: torch.Tensor):
182
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184
+
185
+ def forward(self, x: torch.Tensor):
186
+ x = x + self.attention(self.ln_1(x))
187
+ x = x + self.mlp(self.ln_2(x))
188
+ return x
189
+
190
+
191
+ class Transformer(nn.Module):
192
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193
+ super().__init__()
194
+ self.width = width
195
+ self.layers = layers
196
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197
+
198
+ def forward(self, x: torch.Tensor):
199
+ return self.resblocks(x)
200
+
201
+
202
+ class VisualTransformer(nn.Module):
203
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
204
+ super().__init__()
205
+ self.input_resolution = input_resolution
206
+ self.output_dim = output_dim
207
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
208
+
209
+ scale = width ** -0.5
210
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
211
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
212
+ self.ln_pre = LayerNorm(width)
213
+
214
+ self.transformer = Transformer(width, layers, heads)
215
+
216
+ self.ln_post = LayerNorm(width)
217
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
218
+
219
+ def forward(self, x: torch.Tensor):
220
+ x = self.conv1(x) # shape = [*, width, grid, grid]
221
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
222
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
223
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
224
+ x = x + self.positional_embedding.to(x.dtype)
225
+ x = self.ln_pre(x)
226
+
227
+ x = x.permute(1, 0, 2) # NLD -> LND
228
+ x = self.transformer(x)
229
+ x = x.permute(1, 0, 2) # LND -> NLD
230
+
231
+ x = self.ln_post(x[:, 0, :])
232
+
233
+ if self.proj is not None:
234
+ x = x @ self.proj
235
+
236
+ return x
237
+
238
+
239
+ class CLIP(nn.Module):
240
+ def __init__(self,
241
+ embed_dim: int,
242
+ # vision
243
+ image_resolution: int,
244
+ vision_layers: Union[Tuple[int, int, int, int], int],
245
+ vision_width: int,
246
+ vision_patch_size: int,
247
+ # text
248
+ context_length: int,
249
+ vocab_size: int,
250
+ transformer_width: int,
251
+ transformer_heads: int,
252
+ transformer_layers: int
253
+ ):
254
+ super().__init__()
255
+
256
+ self.context_length = context_length
257
+
258
+ if isinstance(vision_layers, (tuple, list)):
259
+ vision_heads = vision_width * 32 // 64
260
+ self.visual = ModifiedResNet(
261
+ layers=vision_layers,
262
+ output_dim=embed_dim,
263
+ heads=vision_heads,
264
+ input_resolution=image_resolution,
265
+ width=vision_width
266
+ )
267
+ else:
268
+ vision_heads = vision_width // 64
269
+ self.visual = VisualTransformer(
270
+ input_resolution=image_resolution,
271
+ patch_size=vision_patch_size,
272
+ width=vision_width,
273
+ layers=vision_layers,
274
+ heads=vision_heads,
275
+ output_dim=embed_dim
276
+ )
277
+
278
+ self.transformer = Transformer(
279
+ width=transformer_width,
280
+ layers=transformer_layers,
281
+ heads=transformer_heads,
282
+ attn_mask=self.build_attention_mask()
283
+ )
284
+
285
+ self.vocab_size = vocab_size
286
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
287
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
288
+ self.ln_final = LayerNorm(transformer_width)
289
+
290
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
291
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
292
+
293
+ self.initialize_parameters()
294
+
295
+ def initialize_parameters(self):
296
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
297
+ nn.init.normal_(self.positional_embedding, std=0.01)
298
+
299
+ if isinstance(self.visual, ModifiedResNet):
300
+ if self.visual.attnpool is not None:
301
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
302
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
303
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
304
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
305
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
306
+
307
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
308
+ for name, param in resnet_block.named_parameters():
309
+ if name.endswith("bn3.weight"):
310
+ nn.init.zeros_(param)
311
+
312
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
313
+ attn_std = self.transformer.width ** -0.5
314
+ fc_std = (2 * self.transformer.width) ** -0.5
315
+ for block in self.transformer.resblocks:
316
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
317
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
318
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
319
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
320
+
321
+ if self.text_projection is not None:
322
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
323
+
324
+ def build_attention_mask(self):
325
+ # lazily create causal attention mask, with full attention between the vision tokens
326
+ # pytorch uses additive attention mask; fill with -inf
327
+ mask = torch.empty(self.context_length, self.context_length)
328
+ mask.fill_(float("-inf"))
329
+ mask.triu_(1) # zero out the lower diagonal
330
+ return mask
331
+
332
+ @property
333
+ def dtype(self):
334
+ return self.visual.conv1.weight.dtype
335
+
336
+ def encode_image(self, image):
337
+ return self.visual(image.type(self.dtype))
338
+
339
+ def encode_text(self, text):
340
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
341
+
342
+ x = x + self.positional_embedding.type(self.dtype)
343
+ x = x.permute(1, 0, 2) # NLD -> LND
344
+ x = self.transformer(x)
345
+ x = x.permute(1, 0, 2) # LND -> NLD
346
+ x = self.ln_final(x).type(self.dtype)
347
+
348
+ # x.shape = [batch_size, n_ctx, transformer.width]
349
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
350
+ eos_x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
351
+
352
+ return dict(last_hidden_state=x, pooler_output=eos_x)
353
+
354
+ def forward(self, image, text):
355
+ image_features = self.encode_image(image)
356
+ text_features = self.encode_text(text)
357
+
358
+ # normalized features
359
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
360
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
361
+
362
+ # cosine similarity as logits
363
+ logit_scale = self.logit_scale.exp()
364
+ logits_per_image = logit_scale * image_features @ text_features.t()
365
+ logits_per_text = logit_scale * text_features @ image_features.t()
366
+
367
+ # shape = [global_batch_size, global_batch_size]
368
+ return logits_per_image, logits_per_text
369
+
370
+
371
+ def convert_weights(model: nn.Module):
372
+ """Convert applicable model parameters to fp16"""
373
+
374
+ def _convert_weights_to_fp16(l):
375
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
376
+ l.weight.data = l.weight.data.half()
377
+ if l.bias is not None:
378
+ l.bias.data = l.bias.data.half()
379
+
380
+ if isinstance(l, nn.MultiheadAttention):
381
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
382
+ tensor = getattr(l, attr)
383
+ if tensor is not None:
384
+ tensor.data = tensor.data.half()
385
+
386
+ for name in ["text_projection", "proj"]:
387
+ if hasattr(l, name):
388
+ attr = getattr(l, name)
389
+ if attr is not None:
390
+ attr.data = attr.data.half()
391
+
392
+ model.apply(_convert_weights_to_fp16)
393
+
394
+
395
+ def build_model(state_dict: dict):
396
+ vit = "visual.proj" in state_dict
397
+
398
+ if vit:
399
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
400
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
401
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
402
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
403
+ image_resolution = vision_patch_size * grid_size
404
+ else:
405
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
406
+ vision_layers = tuple(counts)
407
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
408
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
409
+ vision_patch_size = None
410
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
411
+ image_resolution = output_width * 32
412
+
413
+ embed_dim = state_dict["text_projection"].shape[1]
414
+ context_length = state_dict["positional_embedding"].shape[0]
415
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
416
+ transformer_width = state_dict["ln_final.weight"].shape[0]
417
+ transformer_heads = transformer_width // 64
418
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
419
+
420
+ model = CLIP(
421
+ embed_dim,
422
+ image_resolution, vision_layers, vision_width, vision_patch_size,
423
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
424
+ )
425
+
426
+ for key in ["input_resolution", "context_length", "vocab_size"]:
427
+ if key in state_dict:
428
+ del state_dict[key]
429
+
430
+ convert_weights(model)
431
+ model.load_state_dict(state_dict)
432
+ return model.eval()
run_on_video/clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
+ @lru_cache()
11
+ def default_bpe():
12
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
+ @lru_cache()
16
+ def bytes_to_unicode():
17
+ """
18
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
19
+ The reversible bpe codes work on unicode strings.
20
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
23
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
25
+ """
26
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
+ cs = bs[:]
28
+ n = 0
29
+ for b in range(2**8):
30
+ if b not in bs:
31
+ bs.append(b)
32
+ cs.append(2**8+n)
33
+ n += 1
34
+ cs = [chr(n) for n in cs]
35
+ return dict(zip(bs, cs))
36
+
37
+
38
+ def get_pairs(word):
39
+ """Return set of symbol pairs in a word.
40
+ Word is represented as tuple of symbols (symbols being variable-length strings).
41
+ """
42
+ pairs = set()
43
+ prev_char = word[0]
44
+ for char in word[1:]:
45
+ pairs.add((prev_char, char))
46
+ prev_char = char
47
+ return pairs
48
+
49
+
50
+ def basic_clean(text):
51
+ text = ftfy.fix_text(text)
52
+ text = html.unescape(html.unescape(text))
53
+ return text.strip()
54
+
55
+
56
+ def whitespace_clean(text):
57
+ text = re.sub(r'\s+', ' ', text)
58
+ text = text.strip()
59
+ return text
60
+
61
+
62
+ class SimpleTokenizer(object):
63
+ def __init__(self, bpe_path: str = default_bpe()):
64
+ self.byte_encoder = bytes_to_unicode()
65
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
+ merges = merges[1:49152-256-2+1]
68
+ merges = [tuple(merge.split()) for merge in merges]
69
+ vocab = list(bytes_to_unicode().values())
70
+ vocab = vocab + [v+'</w>' for v in vocab]
71
+ for merge in merges:
72
+ vocab.append(''.join(merge))
73
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
+ self.encoder = dict(zip(vocab, range(len(vocab))))
75
+ self.decoder = {v: k for k, v in self.encoder.items()}
76
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
+
80
+ def bpe(self, token):
81
+ if token in self.cache:
82
+ return self.cache[token]
83
+ word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
+ pairs = get_pairs(word)
85
+
86
+ if not pairs:
87
+ return token+'</w>'
88
+
89
+ while True:
90
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
+ if bigram not in self.bpe_ranks:
92
+ break
93
+ first, second = bigram
94
+ new_word = []
95
+ i = 0
96
+ while i < len(word):
97
+ try:
98
+ j = word.index(first, i)
99
+ new_word.extend(word[i:j])
100
+ i = j
101
+ except:
102
+ new_word.extend(word[i:])
103
+ break
104
+
105
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
+ new_word.append(first+second)
107
+ i += 2
108
+ else:
109
+ new_word.append(word[i])
110
+ i += 1
111
+ new_word = tuple(new_word)
112
+ word = new_word
113
+ if len(word) == 1:
114
+ break
115
+ else:
116
+ pairs = get_pairs(word)
117
+ word = ' '.join(word)
118
+ self.cache[token] = word
119
+ return word
120
+
121
+ def encode(self, text):
122
+ bpe_tokens = []
123
+ text = whitespace_clean(basic_clean(text)).lower()
124
+ for token in re.findall(self.pat, text):
125
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
+ return bpe_tokens
128
+
129
+ def decode(self, tokens):
130
+ text = ''.join([self.decoder[token] for token in tokens])
131
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
+ return text
run_on_video/clip_feature_extractor.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch as th
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ from video_loader import VideoLoader
7
+ from torch.utils.data import DataLoader
8
+ import argparse
9
+ from preprocessing import Preprocessing
10
+ import torch.nn.functional as F
11
+ from tqdm import tqdm
12
+ import os
13
+ import sys
14
+ from feature_extractor import clip
15
+ import argparse
16
+
17
+ #################################
18
+ model_version = "ViT-B/32"
19
+ output_feat_size = 512
20
+ clip_len = 2
21
+ overwrite = True
22
+ num_decoding_thread = 4
23
+ half_precision = False
24
+
25
+ @torch.no_grad()
26
+ def extractor(vid_path, text, output_file):
27
+ dataset = VideoLoader(
28
+ vid_path,
29
+ framerate=1/clip_len,
30
+ size=224,
31
+ centercrop=True,
32
+ overwrite=overwrite,
33
+ model_version=model_version
34
+ )
35
+ n_dataset = len(dataset)
36
+ loader = DataLoader(
37
+ dataset,
38
+ batch_size=1,
39
+ shuffle=False,
40
+ num_workers=num_decoding_thread,
41
+ sampler=sampler if n_dataset > 10 else None,
42
+ )
43
+ preprocess = Preprocessing()
44
+ model, _ = clip.load(model_version, device="cuda", jit=False)
45
+
46
+ encoded_texts = clip.tokenize(text).to('cuda')
47
+ text_feature = model.encode_text(encoded_texts)['last_hidden_state']
48
+ valid_lengths = (encoded_texts != 0).sum(1).tolist()[0]
49
+ text_feature = text_feature[0, :valid_lengths].cpu().numpy()
50
+ np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature)
51
+
52
+ totatl_num_frames = 0
53
+ with th.no_grad():
54
+ for k, data in enumerate(tqdm(loader)):
55
+ input_file = data['input'][0]
56
+ if os.path.isfile(output_file):
57
+ # print(f'Video {input_file} already processed.')
58
+ continue
59
+ elif not os.path.isfile(input_file):
60
+ print(f'{input_file}, does not exist.\n')
61
+ elif len(data['video'].shape) > 4:
62
+ video = data['video'].squeeze(0)
63
+ if len(video.shape) == 4:
64
+ video = preprocess(video)
65
+ n_chunk = len(video)
66
+ vid_features = th.cuda.FloatTensor(
67
+ n_chunk, output_feat_size).fill_(0)
68
+ n_iter = int(math.ceil(n_chunk))
69
+ for i in range(n_iter):
70
+ min_ind = i
71
+ max_ind = (i + 1)
72
+ video_batch = video[min_ind:max_ind].cuda()
73
+ batch_features = model.encode_image(video_batch)
74
+ vid_features[min_ind:max_ind] = batch_features
75
+ vid_features = vid_features.cpu().numpy()
76
+ if half_precision:
77
+ vid_features = vid_features.astype('float16')
78
+ totatl_num_frames += vid_features.shape[0]
79
+ # safeguard output path before saving
80
+ dirname = os.path.dirname(output_file)
81
+ if not os.path.exists(dirname):
82
+ print(f"Output directory {dirname} does not exists, creating...")
83
+ os.makedirs(dirname)
84
+ np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features)
85
+ else:
86
+ print(f'{input_file}, failed at ffprobe.\n')
87
+
88
+ print(f"Total number of frames: {totatl_num_frames}")
89
+
90
+ if __name__ == "__main__":
91
+ parser = argparse.ArgumentParser(description='')
92
+ parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4')
93
+ parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.')
94
+ parser.add_argument('--save_dir', type=str, default='./tmp')
95
+ args = parser.parse_args()
96
+
97
+ query = ' '.join(args.text)
98
+
99
+ print(args.vid_path)
100
+ print(query)
101
+ extractor(args.vid_path, [query], args.save_dir)
run_on_video/data_utils.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import numpy as np
4
+ import ffmpeg
5
+ import math
6
+ from run_on_video import clip
7
+
8
+
9
+ class ClipFeatureExtractor:
10
+ def __init__(self, framerate=1/2, size=224, centercrop=True, model_name_or_path="ViT-B/32", device="cuda"):
11
+ self.video_loader = VideoLoader(framerate=framerate, size=size, centercrop=centercrop)
12
+ print("Loading CLIP models")
13
+ self.clip_extractor, _ = clip.load(model_name_or_path, device=device, jit=False)
14
+ self.tokenizer = clip.tokenize
15
+ self.video_preprocessor = Preprocessing()
16
+ self.device = device
17
+
18
+ @torch.no_grad()
19
+ def encode_video(self, video_path: str, bsz=60):
20
+ video_frames = self.video_loader.read_video_from_file(video_path) # (T, H, W, 3)
21
+ video_frames = self.video_preprocessor(video_frames)
22
+ n_frames = len(video_frames)
23
+ n_batch = int(math.ceil(n_frames / bsz))
24
+ video_features = []
25
+ for i in range(n_batch):
26
+ st_idx = i * bsz
27
+ ed_idx = (i+1) * bsz
28
+ _video_frames = video_frames[st_idx:ed_idx].to(self.device)
29
+ _video_features = self.clip_extractor.encode_image(_video_frames)
30
+ video_features.append(_video_features)
31
+ video_features = torch.cat(video_features, dim=0)
32
+ return video_features # (T=#frames, d) torch tensor
33
+
34
+ @torch.no_grad()
35
+ def encode_text(self, text_list, bsz=60):
36
+ n_text = len(text_list)
37
+ n_batch = int(math.ceil(n_text / bsz))
38
+ text_features = []
39
+ for i in range(n_batch):
40
+ st_idx = i * bsz
41
+ ed_idx = (i+1) * bsz
42
+ encoded_texts = self.tokenizer(text_list[st_idx:ed_idx], context_length=77).to(self.device)
43
+ output = self.clip_extractor.encode_text(encoded_texts)
44
+ valid_lengths = (encoded_texts != 0).sum(1).tolist()
45
+ batch_last_hidden_states = output["last_hidden_state"]
46
+ for j, valid_len in enumerate(valid_lengths):
47
+ text_features.append(batch_last_hidden_states[j, :valid_len])
48
+ return text_features # List([L_j, d]) torch tensor
49
+
50
+
51
+ def convert_to_float(frac_str):
52
+ try:
53
+ return float(frac_str)
54
+ except ValueError:
55
+ try:
56
+ num, denom = frac_str.split('/')
57
+ except ValueError:
58
+ return None
59
+ try:
60
+ leading, num = num.split(' ')
61
+ except ValueError:
62
+ return float(num) / float(denom)
63
+ if float(leading) < 0:
64
+ sign_mult = -1
65
+ else:
66
+ sign_mult = 1
67
+ return float(leading) + sign_mult * (float(num) / float(denom))
68
+
69
+
70
+ class Normalize(object):
71
+
72
+ def __init__(self, mean, std):
73
+ self.mean = torch.FloatTensor(mean).view(1, 3, 1, 1)
74
+ self.std = torch.FloatTensor(std).view(1, 3, 1, 1)
75
+
76
+ def __call__(self, tensor):
77
+ tensor = (tensor - self.mean) / (self.std + 1e-8)
78
+ return tensor
79
+
80
+
81
+ class Preprocessing(object):
82
+
83
+ def __init__(self):
84
+ self.norm = Normalize(
85
+ mean=[0.48145466, 0.4578275, 0.40821073],
86
+ std=[0.26862954, 0.26130258, 0.27577711])
87
+
88
+ def __call__(self, tensor):
89
+ tensor = tensor / 255.0
90
+ tensor = self.norm(tensor)
91
+ return tensor
92
+
93
+
94
+ class VideoLoader:
95
+ """Pytorch video loader.
96
+ Copied and modified from:
97
+ https://github.com/linjieli222/HERO_Video_Feature_Extractor/blob/main/clip/video_loader.py
98
+ """
99
+ def __init__(
100
+ self,
101
+ framerate=1/2,
102
+ size=224,
103
+ centercrop=True,
104
+ ):
105
+ self.centercrop = centercrop
106
+ self.size = size
107
+ self.framerate = framerate
108
+
109
+ def _get_video_info(self, video_path):
110
+ probe = ffmpeg.probe(video_path)
111
+ video_stream = next((stream for stream in probe['streams']
112
+ if stream['codec_type'] == 'video'), None)
113
+ width = int(video_stream['width'])
114
+ height = int(video_stream['height'])
115
+ fps = math.floor(convert_to_float(video_stream['avg_frame_rate']))
116
+ try:
117
+ frames_length = int(video_stream['nb_frames'])
118
+ duration = float(video_stream['duration'])
119
+ except Exception:
120
+ frames_length, duration = -1, -1
121
+ info = {"duration": duration, "frames_length": frames_length,
122
+ "fps": fps, "height": height, "width": width}
123
+ return info
124
+
125
+ def _get_output_dim(self, h, w):
126
+ if isinstance(self.size, tuple) and len(self.size) == 2:
127
+ return self.size
128
+ elif h >= w:
129
+ return int(h * self.size / w), self.size
130
+ else:
131
+ return self.size, int(w * self.size / h)
132
+
133
+ def read_video_from_file(self, video_path):
134
+ try:
135
+ info = self._get_video_info(video_path)
136
+ h, w = info["height"], info["width"]
137
+ except Exception:
138
+ print('ffprobe failed at: {}'.format(video_path))
139
+ return {'video': torch.zeros(1), 'input': video_path,
140
+ 'info': {}}
141
+ height, width = self._get_output_dim(h, w)
142
+ try:
143
+ duration = info["duration"]
144
+ fps = self.framerate
145
+ if duration > 0 and duration < 1/fps+0.1:
146
+ fps = 2/max(int(duration), 1)
147
+ print(duration, fps)
148
+ except Exception:
149
+ fps = self.framerate
150
+ cmd = (
151
+ ffmpeg
152
+ .input(video_path)
153
+ .filter('fps', fps=fps)
154
+ .filter('scale', width, height)
155
+ )
156
+ if self.centercrop:
157
+ x = int((width - self.size) / 2.0)
158
+ y = int((height - self.size) / 2.0)
159
+ cmd = cmd.crop(x, y, self.size, self.size)
160
+ out, _ = (
161
+ cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
162
+ .run(capture_stdout=True, quiet=True)
163
+ )
164
+ if self.centercrop and isinstance(self.size, int):
165
+ height, width = self.size, self.size
166
+ video = np.frombuffer(out, np.uint8).reshape(
167
+ [-1, height, width, 3])
168
+ video = torch.from_numpy(video.astype('float32'))
169
+ video = video.permute(0, 3, 1, 2)
170
+ return video
run_on_video/preprocessing.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch as th
2
+
3
+
4
+ class Normalize(object):
5
+
6
+ def __init__(self, mean, std):
7
+ self.mean = th.FloatTensor(mean).view(1, 3, 1, 1)
8
+ self.std = th.FloatTensor(std).view(1, 3, 1, 1)
9
+
10
+ def __call__(self, tensor):
11
+ tensor = (tensor - self.mean) / (self.std + 1e-8)
12
+ return tensor
13
+
14
+
15
+ class Preprocessing(object):
16
+
17
+ def __init__(self):
18
+ self.norm = Normalize(
19
+ mean=[0.48145466, 0.4578275, 0.40821073],
20
+ std=[0.26862954, 0.26130258, 0.27577711])
21
+
22
+ def __call__(self, tensor):
23
+ tensor = tensor / 255.0
24
+ tensor = self.norm(tensor)
25
+ return tensor
run_on_video/text_extractor.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import sys
3
+ import json
4
+ import torch
5
+ import numpy as np
6
+ from run_on_video.data_utils import ClipFeatureExtractor
7
+ import torch.nn.functional as F
8
+ import tqdm
9
+ import os
10
+
11
+ query_list = []
12
+ qid_list = []
13
+ dataset = 'charades'
14
+ split = 'test'
15
+
16
+ save_dir = f''
17
+
18
+ with open(f"data/{dataset}/metadata/{dataset}_{split}.jsonl", 'r') as f:
19
+ while True:
20
+ line = f.readline()
21
+ if not line:
22
+ break
23
+ js = json.loads(line)
24
+ query_list.append(js['query'])
25
+ qid_list.append(str(js['qid']))
26
+
27
+ # clip
28
+ feature_extractor = ClipFeatureExtractor(
29
+ framerate=1 / 2, size=224, centercrop=True,
30
+ model_name_or_path="ViT-B/32", device='cuda'
31
+ )
32
+ # pdb.set_trace()
33
+ query_feats = feature_extractor.encode_text(query_list)
34
+
35
+ for i in tqdm.tqdm(range(len(query_feats))):
36
+ np.savez(save_dir + '/' + qid_list[i], last_hidden_state=query_feats[i].cpu().numpy())
run_on_video/video_extractor.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pdb
2
+ import torch as th
3
+ import math
4
+ import numpy as np
5
+ import torch
6
+ from run_on_video.video_loader import VideoLoader
7
+ from torch.utils.data import DataLoader
8
+ import argparse
9
+ from run_on_video.preprocessing import Preprocessing
10
+ import torch.nn.functional as F
11
+ from tqdm import tqdm
12
+ import os
13
+ import sys
14
+ from run_on_video import clip
15
+ import argparse
16
+
17
+ #################################
18
+ @torch.no_grad()
19
+ def vid2clip(model, vid_path, output_file,
20
+ model_version="ViT-B/32", output_feat_size=512,
21
+ clip_len=2, overwrite=True, num_decoding_thread=4, half_precision=False):
22
+ dataset = VideoLoader(
23
+ vid_path,
24
+ framerate=1/clip_len,
25
+ size=224,
26
+ centercrop=True,
27
+ overwrite=overwrite,
28
+ model_version=model_version
29
+ )
30
+ n_dataset = len(dataset)
31
+ loader = DataLoader(
32
+ dataset,
33
+ batch_size=1,
34
+ shuffle=False,
35
+ num_workers=num_decoding_thread,
36
+ sampler=None,
37
+ )
38
+ preprocess = Preprocessing()
39
+ device_id = next(model.parameters()).device
40
+
41
+ totatl_num_frames = 0
42
+ with th.no_grad():
43
+ for k, data in enumerate(tqdm(loader)):
44
+ input_file = data['input'][0]
45
+ if os.path.isfile(output_file):
46
+ # print(f'Video {input_file} already processed.')
47
+ continue
48
+ elif not os.path.isfile(input_file):
49
+ print(f'{input_file}, does not exist.\n')
50
+ elif len(data['video'].shape) > 4:
51
+ video = data['video'].squeeze(0)
52
+ if len(video.shape) == 4:
53
+ video = preprocess(video)
54
+ n_chunk = len(video)
55
+ vid_features = th.cuda.FloatTensor(
56
+ n_chunk, output_feat_size).fill_(0)
57
+ n_iter = int(math.ceil(n_chunk))
58
+ for i in range(n_iter):
59
+ min_ind = i
60
+ max_ind = (i + 1)
61
+ video_batch = video[min_ind:max_ind].to(device_id)
62
+ batch_features = model.encode_image(video_batch)
63
+ vid_features[min_ind:max_ind] = batch_features
64
+ vid_features = vid_features.cpu().numpy()
65
+ if half_precision:
66
+ vid_features = vid_features.astype('float16')
67
+ totatl_num_frames += vid_features.shape[0]
68
+ # safeguard output path before saving
69
+ dirname = os.path.dirname(output_file)
70
+ if not os.path.exists(dirname):
71
+ print(f"Output directory {dirname} does not exists, creating...")
72
+ os.makedirs(dirname)
73
+ np.savez(os.path.join(output_file, 'vid.npz'), features=vid_features)
74
+ else:
75
+ print(f'{input_file}, failed at ffprobe.\n')
76
+ print(f"Total number of frames: {totatl_num_frames}")
77
+ return vid_features
78
+
79
+ def txt2clip(model, text, output_file):
80
+ device_id = next(model.parameters()).device
81
+ encoded_texts = clip.tokenize(text).to(device_id)
82
+ text_feature = model.encode_text(encoded_texts)['last_hidden_state']
83
+ valid_lengths = (encoded_texts != 0).sum(1).tolist()[0]
84
+ text_feature = text_feature[0, :valid_lengths].detach().cpu().numpy()
85
+
86
+ np.savez(os.path.join(output_file, 'txt.npz'), features=text_feature)
87
+ return text_feature
88
+
89
+ if __name__ == "__main__":
90
+ parser = argparse.ArgumentParser(description='')
91
+ parser.add_argument('--vid_path', type=str, default='/data/home/qinghonglin/dataset/charades/videos/Charades_v1_480/0A8CF.mp4')
92
+ parser.add_argument('--text', nargs='+', type=str, default='a boy is drinking.')
93
+ parser.add_argument('--save_dir', type=str, default='./tmp')
94
+ args = parser.parse_args()