UncleWang233 committed on
Commit
08f69f6
·
1 Parent(s): a1f4877
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. README.md +4 -5
  2. app.py +331 -0
  3. data_utils/__init__.py +0 -0
  4. data_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  5. data_utils/__pycache__/utils.cpython-310.pyc +0 -0
  6. data_utils/__pycache__/utils.cpython-39.pyc +0 -0
  7. data_utils/ext/synchformer/LICENSE +21 -0
  8. data_utils/ext/synchformer/__init__.py +1 -0
  9. data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc +0 -0
  10. data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc +0 -0
  11. data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc +0 -0
  12. data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc +0 -0
  13. data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc +0 -0
  14. data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc +0 -0
  15. data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc +0 -0
  16. data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc +0 -0
  17. data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc +0 -0
  18. data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc +0 -0
  19. data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc +0 -0
  20. data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc +0 -0
  21. data_utils/ext/synchformer/divided_224_16x4.yaml +84 -0
  22. data_utils/ext/synchformer/motionformer.py +400 -0
  23. data_utils/ext/synchformer/synchformer.py +55 -0
  24. data_utils/ext/synchformer/utils.py +92 -0
  25. data_utils/ext/synchformer/video_model_builder.py +277 -0
  26. data_utils/ext/synchformer/vit_helper.py +399 -0
  27. data_utils/utils.py +115 -0
  28. data_utils/v2a_utils/__init__.py +0 -0
  29. data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  30. data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc +0 -0
  31. data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc +0 -0
  32. data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc +0 -0
  33. data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc +0 -0
  34. data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc +0 -0
  35. data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc +0 -0
  36. data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc +0 -0
  37. data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc +0 -0
  38. data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc +0 -0
  39. data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc +0 -0
  40. data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc +0 -0
  41. data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc +0 -0
  42. data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc +0 -0
  43. data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc +0 -0
  44. data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc +0 -0
  45. data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc +0 -0
  46. data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc +0 -0
  47. data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc +0 -0
  48. data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc +0 -0
  49. data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc +0 -0
  50. data_utils/v2a_utils/feature_utils_224.py +182 -0
README.md CHANGED
@@ -1,14 +1,13 @@
1
  ---
2
- title: ThinkSound
3
- emoji: 🌍
4
- colorFrom: green
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.35.0
8
  app_file: app.py
9
  pinned: false
10
- license: apache-2.0
11
- short_description: 'demo of ThinkSound '
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Test
3
+ emoji: 📚
4
+ colorFrom: gray
5
  colorTo: gray
6
  sdk: gradio
7
  sdk_version: 5.35.0
8
  app_file: app.py
9
  pinned: false
10
+ license: mit
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,331 @@
1
+ from prefigure.prefigure import get_all_args, push_wandb_config
2
+ import json
3
+ import os
4
+ os.environ["GRADIO_TEMP_DIR"] = "./.gradio_tmp"
5
+ import re
6
+ import torch
7
+ import torchaudio
8
+ # import pytorch_lightning as pl
9
+ import lightning as L
10
+ from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
11
+ from lightning.pytorch.callbacks import Callback
12
+ from lightning.pytorch.tuner import Tuner
13
+ from lightning.pytorch import seed_everything
14
+ import random
15
+ from datetime import datetime
16
+ # from think_sound.data.dataset import create_dataloader_from_config
17
+ from think_sound.data.datamodule import DataModule
18
+ from think_sound.models import create_model_from_config
19
+ from think_sound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
20
+ from think_sound.training import create_training_wrapper_from_config, create_demo_callback_from_config
21
+ from think_sound.training.utils import copy_state_dict
22
+ from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
23
+ from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
24
+ from torch.utils.data import Dataset
25
+ from typing import Optional, Union
26
+ from torchvision.transforms import v2
27
+ from torio.io import StreamingMediaDecoder
28
+ from torchvision.utils import save_image
29
+ from transformers import AutoProcessor
30
+ import torch.nn.functional as F
31
+ import gradio as gr
32
+ import tempfile
33
+ import subprocess
34
+ from huggingface_hub import hf_hub_download
35
+
36
+ _CLIP_SIZE = 224
37
+ _CLIP_FPS = 8.0
38
+
39
+ _SYNC_SIZE = 224
40
+ _SYNC_FPS = 25.0
41
+
42
+ def pad_to_square(video_tensor):
43
+ if len(video_tensor.shape) != 4:
44
+ raise ValueError("Input tensor must have shape (l, c, h, w)")
45
+
46
+ l, c, h, w = video_tensor.shape
47
+ max_side = max(h, w)
48
+
49
+ pad_h = max_side - h
50
+ pad_w = max_side - w
51
+
52
+ padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
53
+
54
+ video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)
55
+
56
+ return video_padded
57
+
58
+
59
+ class VGGSound(Dataset):
60
+
61
+ def __init__(
62
+ self,
63
+ sample_rate: int = 44_100,
64
+ duration_sec: float = 9.0,
65
+ audio_samples: Optional[int] = 397312,
66
+ normalize_audio: bool = False,
67
+ ):
68
+ if audio_samples is None:
69
+ self.audio_samples = int(sample_rate * duration_sec)
70
+ else:
71
+ self.audio_samples = audio_samples
72
+ effective_duration = audio_samples / sample_rate
73
+ # make sure the duration is close enough, within 15ms
74
+ assert abs(effective_duration - duration_sec) < 0.015, \
75
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
76
+
77
+ self.sample_rate = sample_rate
78
+ self.duration_sec = duration_sec
79
+
80
+ self.expected_audio_length = self.audio_samples
81
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
82
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
83
+
84
+ self.clip_transform = v2.Compose([
85
+ v2.Lambda(pad_to_square), # pad to a square first
86
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
87
+ v2.ToImage(),
88
+ v2.ToDtype(torch.float32, scale=True),
89
+ ])
90
+ self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
91
+ self.sync_transform = v2.Compose([
92
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
93
+ v2.CenterCrop(_SYNC_SIZE),
94
+ v2.ToImage(),
95
+ v2.ToDtype(torch.float32, scale=True),
96
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
97
+ ])
98
+
99
+ self.resampler = {}
100
+
101
+ def sample(self, video_path, label):
102
+ video_id = video_path
103
+
104
+ reader = StreamingMediaDecoder(video_path)
105
+ reader.add_basic_video_stream(
106
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
107
+ frame_rate=_CLIP_FPS,
108
+ format='rgb24',
109
+ )
110
+ reader.add_basic_video_stream(
111
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
112
+ frame_rate=_SYNC_FPS,
113
+ format='rgb24',
114
+ )
115
+
116
+ reader.fill_buffer()
117
+ data_chunk = reader.pop_chunks()
118
+
119
+ clip_chunk = data_chunk[0]
120
+ sync_chunk = data_chunk[1]
121
+
122
+ if sync_chunk is None:
123
+ raise RuntimeError(f'Sync video returned None {video_id}')
124
+
125
+ clip_chunk = clip_chunk[:self.clip_expected_length]
126
+ # import ipdb
127
+ # ipdb.set_trace()
128
+ if clip_chunk.shape[0] != self.clip_expected_length:
129
+ current_length = clip_chunk.shape[0]
130
+ padding_needed = self.clip_expected_length - current_length
131
+
132
+ # Check that no more than 3 frames of padding are needed
133
+ assert padding_needed < 4, f'Cannot pad more than 3 frames, but {padding_needed} needed'
134
+
135
+ # If assertion passes, proceed with padding
136
+ if padding_needed > 0:
137
+ last_frame = clip_chunk[-1]
138
+ print(last_frame.shape)
139
+ # Repeat the last frame to reach the expected length
140
+ padding = last_frame.repeat(padding_needed, 1, 1, 1)
141
+ clip_chunk = torch.cat((clip_chunk, padding), dim=0)
142
+ # raise RuntimeError(f'CLIP video wrong length {video_id}, '
143
+ # f'expected {self.clip_expected_length}, '
144
+ # f'got {clip_chunk.shape[0]}')
145
+
146
+ # save_image(clip_chunk[0] / 255.0,'ori.png')
147
+ clip_chunk = pad_to_square(clip_chunk)
148
+
149
+ clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]
150
+
151
+ sync_chunk = sync_chunk[:self.sync_expected_length]
152
+ if sync_chunk.shape[0] != self.sync_expected_length:
153
+ # pad using the last frame
154
+ current_length = sync_chunk.shape[0]
155
+ last_frame = sync_chunk[-1]
156
+ # repeat the last frame for padding
157
+ padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
158
+ assert self.sync_expected_length - current_length < 12, f'sync can pad at most 11 frames, but {self.sync_expected_length - current_length} needed'
159
+ sync_chunk = torch.cat((sync_chunk, padding), dim=0)
160
+ # raise RuntimeError(f'Sync video wrong length {video_id}, '
161
+ # f'expected {self.sync_expected_length}, '
162
+ # f'got {sync_chunk.shape[0]}')
163
+
164
+ sync_chunk = self.sync_transform(sync_chunk)
165
+ # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
166
+ # and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
167
+ data = {
168
+ 'id': video_id,
169
+ 'caption': label,
170
+ # 'audio': audio_chunk,
171
+ 'clip_video': clip_chunk,
172
+ 'sync_video': sync_chunk,
173
+ }
174
+
175
+ return data
176
+
177
+ # select devices
178
+ if torch.cuda.is_available():
179
+ device = 'cuda'
180
+ extra_device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
181
+ else:
182
+ device = 'cpu'
183
+ extra_device = 'cpu'
184
+
185
+ vae_ckpt = hf_hub_download(repo_id="UncleWang233/occdata", filename="epoch=3-step=100000.ckpt",repo_type="dataset")
186
+ synchformer_ckpt = hf_hub_download(repo_id="UncleWang233/occdata", filename="synchformer_state_dict.pth",repo_type="dataset")
187
+ feature_extractor = FeaturesUtils(
188
+ vae_ckpt=vae_ckpt,
189
+ vae_config='think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json',
190
+ enable_conditions=True,
191
+ synchformer_ckpt=synchformer_ckpt
192
+ ).eval().to(extra_device)
193
+
194
+ preprocesser = VGGSound()
195
+
196
+ args = get_all_args()
197
+
198
+ seed = 10086
199
+
200
+ seed_everything(seed, workers=True)
201
+
202
+
203
+ #Get JSON config from args.model_config
204
+ with open("think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json") as f:
205
+ model_config = json.load(f)
206
+
207
+ model = create_model_from_config(model_config)
208
+
209
+ ## speed up with torch.compile
210
+ if args.compile:
211
+ model = torch.compile(model)
212
+
213
+ if args.pretrained_ckpt_path:
214
+ copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path,prefix='diffusion.')) # autoencoder. diffusion.
215
+
216
+ if args.remove_pretransform_weight_norm == "pre_load":
217
+ remove_weight_norm_from_model(model.pretransform)
218
+
219
+
220
+ load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
221
+ # new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
222
+ model.pretransform.load_state_dict(load_vae_state)
223
+
224
+ # Remove weight_norm from the pretransform if specified
225
+ if args.remove_pretransform_weight_norm == "post_load":
226
+ remove_weight_norm_from_model(model.pretransform)
227
+ ckpt_path = hf_hub_download(repo_id="UncleWang233/occdata", filename="epoch=10-step=68000.ckpt",repo_type="dataset")
228
+ training_wrapper = create_training_wrapper_from_config(model_config, model)
229
+ # choose map_location based on the available device when loading model weights
230
+ if device == 'cuda':
231
+ training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
232
+ else:
233
+ training_wrapper.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict'])
234
+
235
+ def get_audio(video_path, caption):
236
+ # allow an empty caption
237
+ if caption is None:
238
+ caption = ''
239
+ timer = Timer(duration="00:15:00:00")
240
+ data = preprocesser.sample(video_path, caption)
241
+
242
+ preprocessed_data = {}
243
+ metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(data['caption'])
244
+ preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
245
+ preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)
246
+
247
+ t5_features = feature_extractor.encode_t5_text(data['caption'])
248
+ preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)
249
+
250
+ clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
251
+ preprocessed_data['metaclip_features'] = clip_features.detach().cpu().squeeze(0)
252
+
253
+ sync_features = feature_extractor.encode_video_with_sync(data['sync_video'].unsqueeze(0).to(extra_device))
254
+ preprocessed_data['sync_features'] = sync_features.detach().cpu().squeeze(0)
255
+ preprocessed_data['video_exist'] = torch.tensor(True)
256
+
257
+ metadata = [preprocessed_data]
258
+
259
+ batch_size = 1
260
+ length = 194
261
+ with torch.amp.autocast(device):
262
+ conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)
263
+
264
+ video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
265
+ conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
266
+ conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat
267
+
268
+ cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)
269
+ noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
270
+ with torch.amp.autocast(device):
271
+ model = training_wrapper.diffusion.model
272
+ if training_wrapper.diffusion_objective == "v":
273
+ fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
274
+ elif training_wrapper.diffusion_objective == "rectified_flow":
275
+ import time
276
+ start_time = time.time()
277
+ fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
278
+ end_time = time.time()
279
+ execution_time = end_time - start_time
280
+ print(f"Inference time: {execution_time:.2f} s")
281
+ if training_wrapper.diffusion.pretransform is not None:
282
+ fakes = training_wrapper.diffusion.pretransform.decode(fakes)
283
+
284
+ audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
285
+ # save the generated audio to a temporary file
286
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
287
+ torchaudio.save(tmp_audio.name, audios[0], 44100)
288
+ audio_path = tmp_audio.name
289
+ return audio_path
290
+
291
+ # synthesize a new video: mux the generated audio with the original video using ffmpeg
292
+
293
+ def synthesize_video_with_audio(video_file, caption):
294
+ # allow an empty caption
295
+ if caption is None:
296
+ caption = ''
297
+ audio_path = get_audio(video_file, caption)
298
+ with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
299
+ output_video_path = tmp_video.name
300
+ # ffmpeg command: replace the original audio track with the new audio
301
+ cmd = [
302
+ 'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
303
+ '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
304
+ '-shortest', output_video_path
305
+ ]
306
+ subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
307
+ return output_video_path
308
+
309
+ # Gradio interface
310
+ with gr.Blocks() as demo:
311
+ gr.Markdown("# ThinkSound\nUpload a video and an optional caption to get the video with generated audio!")
312
+ with gr.Row():
313
+ video_input = gr.Video(label="Upload video")
314
+ caption_input = gr.Textbox(label="Caption (optional)", placeholder="Can be empty", lines=1)
315
+ output_video = gr.Video(label="Output video")
316
+ btn = gr.Button("Start synthesis")
317
+ btn.click(fn=synthesize_video_with_audio, inputs=[video_input, caption_input], outputs=output_video)
318
+
319
+ gr.Examples(
320
+ examples=[
321
+ ["./examples/1_mute.mp4", "Playing Trumpet"],
322
+ ["./examples/2_mute.mp4", "Axe striking"],
323
+ ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier"],
324
+ ["./examples/4_mute.mp4", "train passing by"],
325
+ ["./examples/5_mute.mp4", "Lighting Firecrackers"]
326
+ ],
327
+ inputs=[video_input, caption_input],
328
+ )
329
+
330
+ demo.launch(share=True)
331
+
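The `pad_to_square` helper above zero-pads the shorter spatial side symmetrically before the CLIP transforms are applied. A minimal standalone sketch of that behavior (the clip size below is made up for illustration):

```python
import torch
import torch.nn.functional as F

def pad_to_square(video_tensor):
    # same logic as in app.py: zero-pad the shorter spatial side symmetrically
    if len(video_tensor.shape) != 4:
        raise ValueError("Input tensor must have shape (l, c, h, w)")
    l, c, h, w = video_tensor.shape
    max_side = max(h, w)
    pad_h, pad_w = max_side - h, max_side - w
    padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
    return F.pad(video_tensor, pad=padding, mode='constant', value=0)

clip = torch.ones(72, 3, 112, 224)          # 72 frames, 3 channels, 112x224 (h, w)
square = pad_to_square(clip)
assert square.shape == (72, 3, 224, 224)    # now square
assert square[:, :, :56].abs().sum() == 0   # top band is zero padding
```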
data_utils/__init__.py ADDED
File without changes
data_utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (149 Bytes). View file
 
data_utils/__pycache__/utils.cpython-310.pyc ADDED
Binary file (4.56 kB). View file
 
data_utils/__pycache__/utils.cpython-39.pyc ADDED
Binary file (4.56 kB). View file
 
data_utils/ext/synchformer/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Vladimir Iashin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
data_utils/ext/synchformer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from data_utils.ext.synchformer.synchformer import Synchformer
data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (243 Bytes). View file
 
data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (241 Bytes). View file
 
data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc ADDED
Binary file (12.7 kB). View file
 
data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc ADDED
Binary file (12.7 kB). View file
 
data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc ADDED
Binary file (1.91 kB). View file
 
data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc ADDED
Binary file (1.9 kB). View file
 
data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.97 kB). View file
 
data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.78 kB). View file
 
data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc ADDED
Binary file (5.84 kB). View file
 
data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc ADDED
Binary file (5.8 kB). View file
 
data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc ADDED
Binary file (10.6 kB). View file
 
data_utils/ext/synchformer/divided_224_16x4.yaml ADDED
@@ -0,0 +1,84 @@
1
+ TRAIN:
2
+ ENABLE: True
3
+ DATASET: Ssv2
4
+ BATCH_SIZE: 32
5
+ EVAL_PERIOD: 5
6
+ CHECKPOINT_PERIOD: 5
7
+ AUTO_RESUME: True
8
+ CHECKPOINT_EPOCH_RESET: True
9
+ CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth
10
+ DATA:
11
+ NUM_FRAMES: 16
12
+ SAMPLING_RATE: 4
13
+ TRAIN_JITTER_SCALES: [256, 320]
14
+ TRAIN_CROP_SIZE: 224
15
+ TEST_CROP_SIZE: 224
16
+ INPUT_CHANNEL_NUM: [3]
17
+ MEAN: [0.5, 0.5, 0.5]
18
+ STD: [0.5, 0.5, 0.5]
19
+ PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2
20
+ PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames
21
+ INV_UNIFORM_SAMPLE: True
22
+ RANDOM_FLIP: False
23
+ REVERSE_INPUT_CHANNEL: True
24
+ USE_RAND_AUGMENT: True
25
+ RE_PROB: 0.0
26
+ USE_REPEATED_AUG: False
27
+ USE_RANDOM_RESIZE_CROPS: False
28
+ COLORJITTER: False
29
+ GRAYSCALE: False
30
+ GAUSSIAN: False
31
+ SOLVER:
32
+ BASE_LR: 1e-4
33
+ LR_POLICY: steps_with_relative_lrs
34
+ LRS: [1, 0.1, 0.01]
35
+ STEPS: [0, 20, 30]
36
+ MAX_EPOCH: 35
37
+ MOMENTUM: 0.9
38
+ WEIGHT_DECAY: 5e-2
39
+ WARMUP_EPOCHS: 0.0
40
+ OPTIMIZING_METHOD: adamw
41
+ USE_MIXED_PRECISION: True
42
+ SMOOTHING: 0.2
43
+ SLOWFAST:
44
+ ALPHA: 8
45
+ VIT:
46
+ PATCH_SIZE: 16
47
+ PATCH_SIZE_TEMP: 2
48
+ CHANNELS: 3
49
+ EMBED_DIM: 768
50
+ DEPTH: 12
51
+ NUM_HEADS: 12
52
+ MLP_RATIO: 4
53
+ QKV_BIAS: True
54
+ VIDEO_INPUT: True
55
+ TEMPORAL_RESOLUTION: 8
56
+ USE_MLP: True
57
+ DROP: 0.0
58
+ POS_DROPOUT: 0.0
59
+ DROP_PATH: 0.2
60
+ IM_PRETRAINED: True
61
+ HEAD_DROPOUT: 0.0
62
+ HEAD_ACT: tanh
63
+ PRETRAINED_WEIGHTS: vit_1k
64
+ ATTN_LAYER: divided
65
+ MODEL:
66
+ NUM_CLASSES: 174
67
+ ARCH: slow
68
+ MODEL_NAME: VisionTransformer
69
+ LOSS_FUNC: cross_entropy
70
+ TEST:
71
+ ENABLE: True
72
+ DATASET: Ssv2
73
+ BATCH_SIZE: 64
74
+ NUM_ENSEMBLE_VIEWS: 1
75
+ NUM_SPATIAL_CROPS: 3
76
+ DATA_LOADER:
77
+ NUM_WORKERS: 4
78
+ PIN_MEMORY: True
79
+ NUM_GPUS: 8
80
+ NUM_SHARDS: 4
81
+ RNG_SEED: 0
82
+ OUTPUT_DIR: .
83
+ TENSORBOARD:
84
+ ENABLE: True
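`motionformer.py` (next file) loads this config with OmegaConf and then patches a few attention-related fields that are not present in the YAML. A minimal sketch of that pattern, assuming the YAML above is saved locally as `divided_224_16x4.yaml`:

```python
from omegaconf import OmegaConf

# load the Motionformer config shipped next to the code
cfg = OmegaConf.load('divided_224_16x4.yaml')

# a few values defined in the YAML above
assert cfg.VIT.PATCH_SIZE == 16
assert cfg.VIT.EMBED_DIM == 768
assert cfg.DATA.TRAIN_CROP_SIZE == 224

# motionformer.py patches fields that the YAML does not define, e.g.:
cfg.VIT.ATTN_DROPOUT = 0.0
cfg.VIT.POS_EMBED = 'separate'
print(OmegaConf.to_yaml(cfg.VIT))
```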
data_utils/ext/synchformer/motionformer.py ADDED
@@ -0,0 +1,400 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import einops
5
+ import torch
6
+ from omegaconf import OmegaConf
7
+ from timm.layers import trunc_normal_
8
+ from torch import nn
9
+
10
+ from data_utils.ext.synchformer.utils import check_if_file_exists_else_download
11
+ from data_utils.ext.synchformer.video_model_builder import VisionTransformer
12
+
13
+ FILE2URL = {
14
+ # cfg
15
+ 'motionformer_224_16x4.yaml':
16
+ 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml',
17
+ 'joint_224_16x4.yaml':
18
+ 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml',
19
+ 'divided_224_16x4.yaml':
20
+ 'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml',
21
+ # ckpt
22
+ 'ssv2_motionformer_224_16x4.pyth':
23
+ 'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth',
24
+ 'ssv2_joint_224_16x4.pyth':
25
+ 'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth',
26
+ 'ssv2_divided_224_16x4.pyth':
27
+ 'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth',
28
+ }
29
+
30
+
31
+ class MotionFormer(VisionTransformer):
32
+ ''' This class serves three purposes:
33
+ 1. Renames the class to MotionFormer.
34
+ 2. Downloads the cfg from the original repo and patches it if needed.
35
+ 3. Takes care of feature extraction by redefining .forward()
36
+ - if `extract_features=True` and `factorize_space_time=False`,
37
+ the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
38
+ - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D)
39
+ and spatial and temporal transformer encoder layers are used.
40
+ - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True`
41
+ the output is of shape (B, D) and spatial and temporal transformer encoder layers
42
+ are used as well as the global representation is extracted from segments (extra pos emb
43
+ is added).
44
+ '''
45
+
46
+ def __init__(
47
+ self,
48
+ extract_features: bool = False,
49
+ ckpt_path: str = None,
50
+ factorize_space_time: bool = None,
51
+ agg_space_module: str = None,
52
+ agg_time_module: str = None,
53
+ add_global_repr: bool = True,
54
+ agg_segments_module: str = None,
55
+ max_segments: int = None,
56
+ ):
57
+ self.extract_features = extract_features
58
+ self.ckpt_path = ckpt_path
59
+ self.factorize_space_time = factorize_space_time
60
+
61
+ if self.ckpt_path is not None:
62
+ check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
63
+ ckpt = torch.load(self.ckpt_path, map_location='cpu')
64
+ mformer_ckpt2cfg = {
65
+ 'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
66
+ 'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
67
+ 'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
68
+ }
69
+ # init from motionformer ckpt or from our Stage I ckpt
70
+ # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to
71
+ # load the state dict differently
72
+ was_pt_on_avclip = self.ckpt_path.endswith(
73
+ '.pt') # checks if it is a stage I ckpt (FIXME: a bit generic)
74
+ if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())):
75
+ cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name]
76
+ elif was_pt_on_avclip:
77
+ # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it)
78
+ s1_cfg = ckpt.get('args', None) # Stage I cfg
79
+ if s1_cfg is not None:
80
+ s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
81
+ # if the stage I ckpt was initialized from a motionformer ckpt or train from scratch
82
+ if s1_vfeat_extractor_ckpt_path is not None:
83
+ cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
84
+ else:
85
+ cfg_fname = 'divided_224_16x4.yaml'
86
+ else:
87
+ cfg_fname = 'divided_224_16x4.yaml'
88
+ else:
89
+ raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
90
+ else:
91
+ was_pt_on_avclip = False
92
+ cfg_fname = 'divided_224_16x4.yaml'
93
+ # logging.info(f'No ckpt_path provided, using {cfg_fname} config.')
94
+
95
+ if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']:
96
+ pos_emb_type = 'separate'
97
+ elif cfg_fname == 'joint_224_16x4.yaml':
98
+ pos_emb_type = 'joint'
99
+
100
+ self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname
101
+
102
+ check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
103
+ mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
104
+ logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')
105
+
106
+ # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
107
+ mformer_cfg.VIT.ATTN_DROPOUT = 0.0
108
+ mformer_cfg.VIT.POS_EMBED = pos_emb_type
109
+ mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
110
+ mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none' # guessing
111
+ mformer_cfg.VIT.APPROX_ATTN_DIM = 64 # from ckpt['cfg']
112
+
113
+ # finally init VisionTransformer with the cfg
114
+ super().__init__(mformer_cfg)
115
+
116
+ # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
117
+ if (self.ckpt_path is not None) and (not was_pt_on_avclip):
118
+ _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
119
+ if len(_ckpt_load_status.missing_keys) > 0 or len(
120
+ _ckpt_load_status.unexpected_keys) > 0:
121
+ logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed.' \
122
+ f'Missing keys: {_ckpt_load_status.missing_keys}, ' \
123
+ f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
124
+ else:
125
+ logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')
126
+
127
+ if self.extract_features:
128
+ assert isinstance(self.norm,
129
+ nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
130
+ # pre-logits are Sequential(nn.Linear(emb, emb), act) and `act` is tanh but see the logger
131
+ self.pre_logits = nn.Identity()
132
+ # we don't need the classification head (saving memory)
133
+ self.head = nn.Identity()
134
+ self.head_drop = nn.Identity()
135
+ # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
136
+ transf_enc_layer_kwargs = dict(
137
+ d_model=self.embed_dim,
138
+ nhead=self.num_heads,
139
+ activation=nn.GELU(),
140
+ batch_first=True,
141
+ dim_feedforward=self.mlp_ratio * self.embed_dim,
142
+ dropout=self.drop_rate,
143
+ layer_norm_eps=1e-6,
144
+ norm_first=True,
145
+ )
146
+ # define adapters if needed
147
+ if self.factorize_space_time:
148
+ if agg_space_module == 'TransformerEncoderLayer':
149
+ self.spatial_attn_agg = SpatialTransformerEncoderLayer(
150
+ **transf_enc_layer_kwargs)
151
+ elif agg_space_module == 'AveragePooling':
152
+ self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
153
+ then_permute_pattern='BS D t -> BS t D')
154
+ if agg_time_module == 'TransformerEncoderLayer':
155
+ self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
156
+ elif agg_time_module == 'AveragePooling':
157
+ self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
158
+ elif 'Identity' in agg_time_module:
159
+ self.temp_attn_agg = nn.Identity()
160
+ # define a global aggregation layer (aggregarate over segments)
161
+ self.add_global_repr = add_global_repr
162
+ if add_global_repr:
163
+ if agg_segments_module == 'TransformerEncoderLayer':
164
+ # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
165
+ # we need to add pos emb (PE) because previously we added the same PE for each segment
166
+ pos_max_len = max_segments if max_segments is not None else 16 # 16 = 10sec//0.64sec + 1
167
+ self.global_attn_agg = TemporalTransformerEncoderLayer(
168
+ add_pos_emb=True,
169
+ pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
170
+ pos_max_len=pos_max_len,
171
+ **transf_enc_layer_kwargs)
172
+ elif agg_segments_module == 'AveragePooling':
173
+ self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')
174
+
175
+ if was_pt_on_avclip:
176
+ # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
177
+ # and keep only the state_dict of the feat extractor
178
+ ckpt_weights = dict()
179
+ for k, v in ckpt['state_dict'].items():
180
+ if k.startswith(('module.v_encoder.', 'v_encoder.')):
181
+ k = k.replace('module.', '').replace('v_encoder.', '')
182
+ ckpt_weights[k] = v
183
+ _load_status = self.load_state_dict(ckpt_weights, strict=False)
184
+ if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
185
+ logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \
186
+ f'Missing keys ({len(_load_status.missing_keys)}): ' \
187
+ f'{_load_status.missing_keys}, \n' \
188
+ f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
189
+ f'{_load_status.unexpected_keys} \n' \
190
+ f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
191
+ else:
192
+ logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')
193
+
194
+ # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1
195
+ # but it used to calculate the number of patches, so we need to set keep it
196
+ self.patch_embed.requires_grad_(False)
197
+
198
+ def forward(self, x):
199
+ '''
200
+ x is of shape (B, S, C, T, H, W) where S is the number of segments.
201
+ '''
202
+ # Batch, Segments, Channels, T=frames, Height, Width
203
+ B, S, C, T, H, W = x.shape
204
+ # Motionformer expects a tensor of shape (1, B, C, T, H, W).
205
+ # The first dimension (1) is a dummy dimension to make the input tensor and won't be used:
206
+ # see `video_model_builder.video_input`.
207
+ # x = x.unsqueeze(0) # (1, B, S, C, T, H, W)
208
+
209
+ orig_shape = (B, S, C, T, H, W)
210
+ x = x.view(B * S, C, T, H, W) # flatten batch and segments
211
+ x = self.forward_segments(x, orig_shape=orig_shape)
212
+ # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
213
+ x = x.view(B, S, *x.shape[1:])
214
+ # x is now of shape (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity`
215
+
216
+ return x # x is (B, S, ...)
217
+
218
+ def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
219
+ '''x is of shape (B*S, C, T, H, W) where S is the number of segments.'''
220
+ x, x_mask = self.forward_features(x)
221
+
222
+ assert self.extract_features
223
+
224
+ # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
225
+ x = x[:,
226
+ 1:, :] # without the CLS token for efficiency (should be safe for LayerNorm and FC)
227
+ x = self.norm(x)
228
+ x = self.pre_logits(x)
229
+ if self.factorize_space_time:
230
+ x = self.restore_spatio_temp_dims(x, orig_shape) # (B*S, D, t, h, w) <- (B*S, t*h*w, D)
231
+
232
+ x = self.spatial_attn_agg(x, x_mask) # (B*S, t, D)
233
+ x = self.temp_attn_agg(
234
+ x) # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`
235
+
236
+ return x
237
+
238
+ def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
239
+ '''
240
+ feats are of shape (B*S, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
241
+ Our goal is to make them of shape (B*S, t, h, w, D) where h, w are the spatial dimensions.
242
+ From `self.patch_embed_3d`, it follows that we could reshape feats with:
243
+ `feats.transpose(1, 2).view(B*S, D, t, h, w)`
244
+ '''
245
+ B, S, C, T, H, W = orig_shape
246
+ D = self.embed_dim
247
+
248
+ # num patches in each dimension
249
+ t = T // self.patch_embed_3d.z_block_size
250
+ h = self.patch_embed_3d.height
251
+ w = self.patch_embed_3d.width
252
+
253
+ feats = feats.permute(0, 2, 1) # (B*S, D, T)
254
+ feats = feats.view(B * S, D, t, h, w) # (B*S, D, t, h, w)
255
+
256
+ return feats
257
+
258
+
259
+ class BaseEncoderLayer(nn.TransformerEncoderLayer):
260
+ '''
261
+ This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token
262
+ to the sequence and outputs the CLS token's representation.
263
+ This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream
264
+ and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream stream.
265
+ We also, optionally, add a positional embedding to the input sequence which
266
+ allows to reuse it for global aggregation (of segments) for both streams.
267
+ '''
268
+
269
+ def __init__(self,
270
+ add_pos_emb: bool = False,
271
+ pos_emb_drop: float = None,
272
+ pos_max_len: int = None,
273
+ *args_transformer_enc,
274
+ **kwargs_transformer_enc):
275
+ super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
276
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
277
+ trunc_normal_(self.cls_token, std=.02)
278
+
279
+ # add positional embedding
280
+ self.add_pos_emb = add_pos_emb
281
+ if add_pos_emb:
282
+ self.pos_max_len = 1 + pos_max_len # +1 (for CLS)
283
+ self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
284
+ self.pos_drop = nn.Dropout(pos_emb_drop)
285
+ trunc_normal_(self.pos_emb, std=.02)
286
+
287
+ self.apply(self._init_weights)
288
+
289
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
290
+ ''' x is of shape (B, N, D); if provided x_mask is of shape (B, N)'''
291
+ batch_dim = x.shape[0]
292
+
293
+ # add CLS token
294
+ cls_tokens = self.cls_token.expand(batch_dim, -1, -1) # expanding to match batch dimension
295
+ x = torch.cat((cls_tokens, x), dim=-2) # (batch_dim, 1+seq_len, D)
296
+ if x_mask is not None:
297
+ cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool,
298
+ device=x_mask.device) # 1=keep; 0=mask
299
+ x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1) # (batch_dim, 1+seq_len)
300
+ B, N = x_mask_w_cls.shape
301
+ # torch expects (N, N) or (B*num_heads, N, N) mask (sadness ahead); torch masks
302
+ x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N)\
303
+ .expand(-1, self.self_attn.num_heads, N, -1)\
304
+ .reshape(B * self.self_attn.num_heads, N, N)
305
+ assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, 'x_mask_w_cls.dtype != bool'
306
+ x_mask_w_cls = ~x_mask_w_cls # invert mask (1=mask)
307
+ else:
308
+ x_mask_w_cls = None
309
+
310
+ # add positional embedding
311
+ if self.add_pos_emb:
312
+ seq_len = x.shape[
313
+ 1] # (don't even think about moving it before the CLS token concatenation)
314
+ assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
315
+ x = x + self.pos_emb[:, :seq_len, :]
316
+ x = self.pos_drop(x)
317
+
318
+ # apply encoder layer (calls nn.TransformerEncoderLayer.forward);
319
+ x = super().forward(src=x, src_mask=x_mask_w_cls) # (batch_dim, 1+seq_len, D)
320
+
321
+ # CLS token is expected to hold spatial information for each frame
322
+ x = x[:, 0, :] # (batch_dim, D)
323
+
324
+ return x
325
+
326
+ def _init_weights(self, m):
327
+ if isinstance(m, nn.Linear):
328
+ trunc_normal_(m.weight, std=.02)
329
+ if isinstance(m, nn.Linear) and m.bias is not None:
330
+ nn.init.constant_(m.bias, 0)
331
+ elif isinstance(m, nn.LayerNorm):
332
+ nn.init.constant_(m.bias, 0)
333
+ nn.init.constant_(m.weight, 1.0)
334
+
335
+ @torch.jit.ignore
336
+ def no_weight_decay(self):
337
+ return {'cls_token', 'pos_emb'}
338
+
339
+
340
+ class SpatialTransformerEncoderLayer(BaseEncoderLayer):
341
+ ''' Aggregates spatial dimensions by applying attention individually to each frame. '''
342
+
343
+ def __init__(self, *args, **kwargs):
344
+ super().__init__(*args, **kwargs)
345
+
346
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
347
+ ''' x is of shape (B*S, D, t, h, w) where S is the number of segments.
348
+ if specified x_mask (B*S, t, h, w), 0=masked, 1=kept
349
+ Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame. '''
350
+ BS, D, t, h, w = x.shape
351
+
352
+ # time as a batch dimension and flatten spatial dimensions as sequence
353
+ x = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
354
+ # similar to mask
355
+ if x_mask is not None:
356
+ x_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)')
357
+
358
+ # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
359
+ x = super().forward(x=x, x_mask=x_mask) # (B*S*t, D)
360
+
361
+ # reshape back to (B*S, t, D)
362
+ x = einops.rearrange(x, '(BS t) D -> BS t D', BS=BS, t=t)
363
+
364
+ # (B*S, t, D)
365
+ return x
366
+
367
+
368
+ class TemporalTransformerEncoderLayer(BaseEncoderLayer):
369
+ ''' Aggregates temporal dimension with attention. Also used with pos emb as global aggregation
370
+ in both streams. '''
371
+
372
+ def __init__(self, *args, **kwargs):
373
+ super().__init__(*args, **kwargs)
374
+
375
+ def forward(self, x):
376
+ ''' x is of shape (B*S, t, D) where S is the number of segments.
377
+ Returns a tensor of shape (B*S, D) pooling temporal information. '''
378
+ BS, t, D = x.shape
379
+
380
+ # apply encoder layer (BaseEncoderLayer.forward) - it will add CLS token and output its representation
381
+ x = super().forward(x) # (B*S, D)
382
+
383
+ return x # (B*S, D)
384
+
385
+
386
+ class AveragePooling(nn.Module):
387
+
388
+ def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
389
+ ''' patterns are e.g. "bs t d -> bs d" '''
390
+ super().__init__()
391
+ # TODO: need to register them as buffers (but fails because these are strings)
392
+ self.reduce_fn = 'mean'
393
+ self.avg_pattern = avg_pattern
394
+ self.then_permute_pattern = then_permute_pattern
395
+
396
+ def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
397
+ x = einops.reduce(x, self.avg_pattern, self.reduce_fn)
398
+ if self.then_permute_pattern is not None:
399
+ x = einops.rearrange(x, self.then_permute_pattern)
400
+ return x
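To make the einops patterns used by `AveragePooling` above concrete, here is a small standalone sketch (the tensor sizes are illustrative; 768 matches the ViT embedding dim in the config):

```python
import torch
from einops import rearrange, reduce

BS, D, t, h, w = 4, 768, 8, 14, 14
x = torch.randn(BS, D, t, h, w)

# agg_space_module == 'AveragePooling':
# 'BS D t h w -> BS D t' (mean over the spatial grid), then 'BS D t -> BS t D'
x = reduce(x, 'BS D t h w -> BS D t', 'mean')
x = rearrange(x, 'BS D t -> BS t D')
assert x.shape == (BS, t, D)

# agg_time_module == 'AveragePooling': 'BS t D -> BS D' (mean over time)
x = reduce(x, 'BS t D -> BS D', 'mean')
assert x.shape == (BS, D)
```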
data_utils/ext/synchformer/synchformer.py ADDED
@@ -0,0 +1,55 @@
1
+ import logging
2
+ from typing import Any, Mapping
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from data_utils.ext.synchformer.motionformer import MotionFormer
8
+
9
+
10
+ class Synchformer(nn.Module):
11
+
12
+ def __init__(self):
13
+ super().__init__()
14
+
15
+ self.vfeat_extractor = MotionFormer(extract_features=True,
16
+ factorize_space_time=True,
17
+ agg_space_module='TransformerEncoderLayer',
18
+ agg_time_module='torch.nn.Identity',
19
+ add_global_repr=False)
20
+
21
+ # self.vfeat_extractor = instantiate_from_config(vfeat_extractor)
22
+ # self.afeat_extractor = instantiate_from_config(afeat_extractor)
23
+ # # bridging the s3d latent dim (1024) into what is specified in the config
24
+ # # to match e.g. the transformer dim
25
+ # self.vproj = instantiate_from_config(vproj)
26
+ # self.aproj = instantiate_from_config(aproj)
27
+ # self.transformer = instantiate_from_config(transformer)
28
+
29
+ def forward(self, vis):
30
+ B, S, Tv, C, H, W = vis.shape
31
+ vis = vis.permute(0, 1, 3, 2, 4, 5) # (B, S, C, Tv, H, W)
32
+ # feat extractors return a tuple of segment-level and global features (ignored for sync)
33
+ # (B, S, tv, D), e.g. (B, 7, 8, 768)
34
+ vis = self.vfeat_extractor(vis)
35
+ return vis
36
+
37
+ def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
38
+ # discard all entries except vfeat_extractor
39
+ sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}
40
+
41
+ return super().load_state_dict(sd, strict)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ model = Synchformer().cuda().eval()
46
+ sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
47
+ model.load_state_dict(sd)
48
+
49
+ vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
50
+ features = model(vid).detach().cpu()
51
+ print(features.shape)
52
+
53
+ # extract and save the state dict only
54
+ # sd = torch.load('./ext_weights/sync_model_audioset.pt')['model']
55
+ # torch.save(sd, './ext_weights/synchformer_state_dict.pth')
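For reference, a tensor-shape-only sketch of the layout `Synchformer.forward` expects, mirroring the example in the `__main__` block (no model is instantiated here, so it runs on CPU without weights):

```python
import torch

# input layout: (B, S, Tv, C, H, W) = (batch, segments, frames per segment, channels, height, width)
B, S, Tv, C, H, W = 2, 7, 16, 3, 224, 224
vid = torch.randn(B, S, Tv, C, H, W)

# inside forward(), the frame and channel axes are swapped before feature extraction
vis = vid.permute(0, 1, 3, 2, 4, 5)
assert vis.shape == (B, S, C, Tv, H, W)

# the visual extractor then returns segment-level features of shape (B, S, tv, D), e.g. (2, 7, 8, 768)
```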
data_utils/ext/synchformer/utils.py ADDED
@@ -0,0 +1,92 @@
1
+ from hashlib import md5
2
+ from pathlib import Path
3
+
4
+ import requests
5
+ from tqdm import tqdm
6
+
7
+ PARENT_LINK = 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a'
8
+ FNAME2LINK = {
9
+ # S3: Synchability: AudioSet (run 2)
10
+ '24-01-22T20-34-52.pt':
11
+ f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt',
12
+ 'cfg-24-01-22T20-34-52.yaml':
13
+ f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml',
14
+ # S2: Synchformer: AudioSet (run 2)
15
+ '24-01-04T16-39-21.pt':
16
+ f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt',
17
+ 'cfg-24-01-04T16-39-21.yaml':
18
+ f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml',
19
+ # S2: Synchformer: AudioSet (run 1)
20
+ '23-08-28T11-23-23.pt':
21
+ f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt',
22
+ 'cfg-23-08-28T11-23-23.yaml':
23
+ f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml',
24
+ # S2: Synchformer: LRS3 (run 2)
25
+ '23-12-23T18-33-57.pt':
26
+ f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt',
27
+ 'cfg-23-12-23T18-33-57.yaml':
28
+ f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml',
29
+ # S2: Synchformer: VGS (run 2)
30
+ '24-01-02T10-00-53.pt':
31
+ f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt',
32
+ 'cfg-24-01-02T10-00-53.yaml':
33
+ f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml',
34
+ # SparseSync: ft VGGSound-Full
35
+ '22-09-21T21-00-52.pt':
36
+ f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt',
37
+ 'cfg-22-09-21T21-00-52.yaml':
38
+ f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml',
39
+ # SparseSync: ft VGGSound-Sparse
40
+ '22-07-28T15-49-45.pt':
41
+ f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt',
42
+ 'cfg-22-07-28T15-49-45.yaml':
43
+ f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml',
44
+ # SparseSync: only pt on LRS3
45
+ '22-07-13T22-25-49.pt':
46
+ f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt',
47
+ 'cfg-22-07-13T22-25-49.yaml':
48
+ f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml',
49
+ # SparseSync: feature extractors
50
+ 'ResNetAudio-22-08-04T09-51-04.pt':
51
+ f'{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt', # 2s
52
+ 'ResNetAudio-22-08-03T23-14-49.pt':
53
+ f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt', # 3s
54
+ 'ResNetAudio-22-08-03T23-14-28.pt':
55
+ f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt', # 4s
56
+ 'ResNetAudio-22-06-24T08-10-33.pt':
57
+ f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt', # 5s
58
+ 'ResNetAudio-22-06-24T17-31-07.pt':
59
+ f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt', # 6s
60
+ 'ResNetAudio-22-06-24T23-57-11.pt':
61
+ f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt', # 7s
62
+ 'ResNetAudio-22-06-25T04-35-42.pt':
63
+ f'{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt', # 8s
64
+ }
65
+
66
+
67
+ def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
68
+ '''Checks if file exists, if not downloads it from the link to the path'''
69
+ path = Path(path)
70
+ if not path.exists():
71
+ path.parent.mkdir(exist_ok=True, parents=True)
72
+ link = fname2link.get(path.name, None)
73
+ if link is None:
74
+ raise ValueError(f'Cannot find the checkpoint file: {path}. '
75
+ f'Please download it manually and ensure the path exists.')
76
+ with requests.get(fname2link[path.name], stream=True) as r:
77
+ total_size = int(r.headers.get('content-length', 0))
78
+ with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
79
+ with open(path, 'wb') as f:
80
+ for data in r.iter_content(chunk_size=chunk_size):
81
+ if data:
82
+ f.write(data)
83
+ pbar.update(chunk_size)
84
+
85
+
86
+ def get_md5sum(path):
87
+ hash_md5 = md5()
88
+ with open(path, 'rb') as f:
89
+ for chunk in iter(lambda: f.read(4096 * 8), b''):
90
+ hash_md5.update(chunk)
91
+ md5sum = hash_md5.hexdigest()
92
+ return md5sum
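A minimal sketch of how these helpers are used together (the checkpoint name is one of the keys in FNAME2LINK; the local `./ext_weights/` directory is an assumption):

```python
from pathlib import Path

from data_utils.ext.synchformer.utils import (FNAME2LINK, check_if_file_exists_else_download,
                                               get_md5sum)

# pick a checkpoint listed in FNAME2LINK and a local cache path (assumed location)
ckpt_path = Path('./ext_weights/24-01-04T16-39-21.pt')

# downloads the file from its FNAME2LINK URL if it is not already on disk
check_if_file_exists_else_download(ckpt_path, FNAME2LINK)

# optional integrity check once the file is present
print(get_md5sum(ckpt_path))
```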
data_utils/ext/synchformer/video_model_builder.py ADDED
@@ -0,0 +1,277 @@
1
+ #!/usr/bin/env python3
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3
+ # Copyright 2020 Ross Wightman
4
+ # Modified Model definition
5
+
6
+ from collections import OrderedDict
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from timm.layers import trunc_normal_
12
+
13
+ from data_utils.ext.synchformer import vit_helper
14
+
15
+
16
+ class VisionTransformer(nn.Module):
17
+ """ Vision Transformer with support for patch or hybrid CNN input stage """
18
+
19
+ def __init__(self, cfg):
20
+ super().__init__()
21
+ self.img_size = cfg.DATA.TRAIN_CROP_SIZE
22
+ self.patch_size = cfg.VIT.PATCH_SIZE
23
+ self.in_chans = cfg.VIT.CHANNELS
24
+ if cfg.TRAIN.DATASET == "Epickitchens":
25
+ self.num_classes = [97, 300]
26
+ else:
27
+ self.num_classes = cfg.MODEL.NUM_CLASSES
28
+ self.embed_dim = cfg.VIT.EMBED_DIM
29
+ self.depth = cfg.VIT.DEPTH
30
+ self.num_heads = cfg.VIT.NUM_HEADS
31
+ self.mlp_ratio = cfg.VIT.MLP_RATIO
32
+ self.qkv_bias = cfg.VIT.QKV_BIAS
33
+ self.drop_rate = cfg.VIT.DROP
34
+ self.drop_path_rate = cfg.VIT.DROP_PATH
35
+ self.head_dropout = cfg.VIT.HEAD_DROPOUT
36
+ self.video_input = cfg.VIT.VIDEO_INPUT
37
+ self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
38
+ self.use_mlp = cfg.VIT.USE_MLP
39
+ self.num_features = self.embed_dim
40
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
41
+ self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
42
+ self.head_act = cfg.VIT.HEAD_ACT
43
+ self.cfg = cfg
44
+
45
+ # Patch Embedding
46
+ self.patch_embed = vit_helper.PatchEmbed(img_size=224,
47
+ patch_size=self.patch_size,
48
+ in_chans=self.in_chans,
49
+ embed_dim=self.embed_dim)
50
+
51
+ # 3D Patch Embedding
52
+ self.patch_embed_3d = vit_helper.PatchEmbed3D(img_size=self.img_size,
53
+ temporal_resolution=self.temporal_resolution,
54
+ patch_size=self.patch_size,
55
+ in_chans=self.in_chans,
56
+ embed_dim=self.embed_dim,
57
+ z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP)
58
+ self.patch_embed_3d.proj.weight.data = torch.zeros_like(
59
+ self.patch_embed_3d.proj.weight.data)
60
+
61
+ # Number of patches
62
+ if self.video_input:
63
+ num_patches = self.patch_embed.num_patches * self.temporal_resolution
64
+ else:
65
+ num_patches = self.patch_embed.num_patches
66
+ self.num_patches = num_patches
67
+
68
+ # CLS token
69
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
70
+ trunc_normal_(self.cls_token, std=.02)
71
+
72
+ # Positional embedding
73
+ self.pos_embed = nn.Parameter(
74
+ torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
75
+ self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
76
+ trunc_normal_(self.pos_embed, std=.02)
77
+
78
+ if self.cfg.VIT.POS_EMBED == "joint":
79
+ self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
80
+ trunc_normal_(self.st_embed, std=.02)
81
+ elif self.cfg.VIT.POS_EMBED == "separate":
82
+ self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim))
83
+
84
+ # Layer Blocks
85
+ dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
86
+ if self.cfg.VIT.ATTN_LAYER == "divided":
87
+             self.blocks = nn.ModuleList([
+                 vit_helper.DividedSpaceTimeBlock(
+                     attn_type=cfg.VIT.ATTN_LAYER,
+                     dim=self.embed_dim,
+                     num_heads=self.num_heads,
+                     mlp_ratio=self.mlp_ratio,
+                     qkv_bias=self.qkv_bias,
+                     drop=self.drop_rate,
+                     attn_drop=self.attn_drop_rate,
+                     drop_path=dpr[i],
+                     norm_layer=norm_layer,
+                 ) for i in range(self.depth)
+             ])
+         else:
+             self.blocks = nn.ModuleList([
+                 vit_helper.Block(attn_type=cfg.VIT.ATTN_LAYER,
+                                  dim=self.embed_dim,
+                                  num_heads=self.num_heads,
+                                  mlp_ratio=self.mlp_ratio,
+                                  qkv_bias=self.qkv_bias,
+                                  drop=self.drop_rate,
+                                  attn_drop=self.attn_drop_rate,
+                                  drop_path=dpr[i],
+                                  norm_layer=norm_layer,
+                                  use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE)
+                 for i in range(self.depth)
+             ])
+         self.norm = norm_layer(self.embed_dim)
+
+         # MLP head
+         if self.use_mlp:
+             hidden_dim = self.embed_dim
+             if self.head_act == 'tanh':
+                 # logging.info("Using TanH activation in MLP")
+                 act = nn.Tanh()
+             elif self.head_act == 'gelu':
+                 # logging.info("Using GELU activation in MLP")
+                 act = nn.GELU()
+             else:
+                 # logging.info("Using ReLU activation in MLP")
+                 act = nn.ReLU()
+             self.pre_logits = nn.Sequential(
+                 OrderedDict([
+                     ('fc', nn.Linear(self.embed_dim, hidden_dim)),
+                     ('act', act),
+                 ]))
+         else:
+             self.pre_logits = nn.Identity()
+
+         # Classifier Head
+         self.head_drop = nn.Dropout(p=self.head_dropout)
+         if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
+             for a, i in enumerate(range(len(self.num_classes))):
+                 setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i]))
+         else:
+             self.head = nn.Linear(self.embed_dim,
+                                   self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+         # Initialize weights
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     @torch.jit.ignore
+     def no_weight_decay(self):
+         if self.cfg.VIT.POS_EMBED == "joint":
+             return {'pos_embed', 'cls_token', 'st_embed'}
+         else:
+             return {'pos_embed', 'cls_token', 'temp_embed'}
+
+     def get_classifier(self):
+         return self.head
+
+     def reset_classifier(self, num_classes, global_pool=''):
+         self.num_classes = num_classes
+         self.head = (nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity())
+
+     def forward_features(self, x):
+         # if self.video_input:
+         #     x = x[0]
+         B = x.shape[0]
+
+         # Tokenize input
+         # if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
+         # for simplicity of mapping between content dimensions (input x) and token dims (after patching)
+         # we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):
+
+         # apply patching on input
+         x = self.patch_embed_3d(x)
+         tok_mask = None
+
+         # else:
+         #     tok_mask = None
+         #     # 2D tokenization
+         #     if self.video_input:
+         #         x = x.permute(0, 2, 1, 3, 4)
+         #         (B, T, C, H, W) = x.shape
+         #         x = x.reshape(B * T, C, H, W)
+
+         #     x = self.patch_embed(x)
+
+         #     if self.video_input:
+         #         (B2, T2, D2) = x.shape
+         #         x = x.reshape(B, T * T2, D2)
+
+         # Append CLS token
+         cls_tokens = self.cls_token.expand(B, -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+         # if tok_mask is not None:
+         #     # prepend 1(=keep) to the mask to account for the CLS token as well
+         #     tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
+
+         # Interpolate positional embeddings
+         # if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
+         #     pos_embed = self.pos_embed
+         #     N = pos_embed.shape[1] - 1
+         #     npatch = int((x.size(1) - 1) / self.temporal_resolution)
+         #     class_emb = pos_embed[:, 0]
+         #     pos_embed = pos_embed[:, 1:]
+         #     dim = x.shape[-1]
+         #     pos_embed = torch.nn.functional.interpolate(
+         #         pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+         #         scale_factor=math.sqrt(npatch / N),
+         #         mode='bicubic',
+         #     )
+         #     pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         #     new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
+         # else:
+         new_pos_embed = self.pos_embed
+         npatch = self.patch_embed.num_patches
+
+         # Add positional embeddings to input
+         if self.video_input:
+             if self.cfg.VIT.POS_EMBED == "separate":
+                 cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
+                 tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
+                 tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
+                 total_pos_embed = tile_pos_embed + tile_temporal_embed
+                 total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
+                 x = x + total_pos_embed
+             elif self.cfg.VIT.POS_EMBED == "joint":
+                 x = x + self.st_embed
+         else:
+             # image input
+             x = x + new_pos_embed
+
+         # Apply positional dropout
+         x = self.pos_drop(x)
+
+         # Encoding using transformer layers
+         for i, blk in enumerate(self.blocks):
+             x = blk(x,
+                     seq_len=npatch,
+                     num_frames=self.temporal_resolution,
+                     approx=self.cfg.VIT.APPROX_ATTN_TYPE,
+                     num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
+                     tok_mask=tok_mask)
+
+         ### v-iashin: I moved it to the forward pass
+         # x = self.norm(x)[:, 0]
+         # x = self.pre_logits(x)
+         ###
+         return x, tok_mask
+
+     # def forward(self, x):
+     #     x = self.forward_features(x)
+     #     ### v-iashin: here. This should leave the same forward output as before
+     #     x = self.norm(x)[:, 0]
+     #     x = self.pre_logits(x)
+     #     ###
+     #     x = self.head_drop(x)
+     #     if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
+     #         output = []
+     #         for head in range(len(self.num_classes)):
+     #             x_out = getattr(self, "head%d" % head)(x)
+     #             if not self.training:
+     #                 x_out = torch.nn.functional.softmax(x_out, dim=-1)
+     #             output.append(x_out)
+     #         return output
+     #     else:
+     #         x = self.head(x)
+     #         if not self.training:
+     #             x = torch.nn.functional.softmax(x, dim=-1)
+     #         return x
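Since forward() is commented out in this file, downstream code is expected to call forward_features and apply the final norm and pre_logits itself. Below is a minimal, illustrative sketch of such a caller; the helper name, the input shape, and the way the model instance is obtained are assumptions for illustration only and are not part of this commit.

# Hypothetical usage sketch (not part of this commit): pool a clip into one feature
# vector the way the commented-out forward() above would have done.
import torch

def extract_clip_feature(model, clip: torch.Tensor) -> torch.Tensor:
    # clip: (B, C, T, H, W), the layout consumed by patch_embed_3d above
    tokens, tok_mask = model.forward_features(clip)  # (B, 1 + T'*N, D) tokens; mask is None here
    feat = model.norm(tokens)[:, 0]                  # LayerNorm, then take the CLS token
    return model.pre_logits(feat)                    # optional MLP head (tanh / gelu / relu)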
data_utils/ext/synchformer/vit_helper.py ADDED
@@ -0,0 +1,399 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+ # Copyright 2020 Ross Wightman
+ # Modified Model definition
+ """Video models."""
+
+ import math
+
+ import torch
+ import torch.nn as nn
+ from einops import rearrange, repeat
+ from timm.layers import to_2tuple
+ from torch import einsum
+ from torch.nn import functional as F
+
+ default_cfgs = {
+     'vit_1k':
+     'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+     'vit_1k_large':
+     'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+ }
+
+
+ def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
+     sim = einsum('b i d, b j d -> b i j', q, k)
+     # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N)
+     if tok_mask is not None:
+         BSH, N = tok_mask.shape
+         sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0,
+                               float('-inf'))  # 1 - broadcasts across N
+     attn = sim.softmax(dim=-1)
+     out = einsum('b i j, b j d -> b i d', attn, v)
+     return out
+
+
+ class DividedAttention(nn.Module):
+
+     def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.proj = nn.Linear(dim, dim)
+
+         # init to zeros
+         self.qkv.weight.data.fill_(0)
+         self.qkv.bias.data.fill_(0)
+         self.proj.weight.data.fill_(1)
+         self.proj.bias.data.fill_(0)
+
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
+         # num of heads variable
+         h = self.num_heads
+
+         # project x to q, k, v values
+         q, k, v = self.qkv(x).chunk(3, dim=-1)
+         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+         if tok_mask is not None:
+             # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d
+             assert len(tok_mask.shape) == 2
+             tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])
+
+         # Scale q
+         q *= self.scale
+
+         # Take out cls_q, cls_k, cls_v
+         (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
+         # the same for masking
+         if tok_mask is not None:
+             cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
+         else:
+             cls_mask, mask_ = None, None
+
+         # let CLS token attend to key / values of all patches across time and space
+         cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)
+
+         # rearrange across time or space
+         q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims),
+                          (q_, k_, v_))
+
+         # expand CLS token keys and values across time or space and concat
+         r = q_.shape[0] // cls_k.shape[0]
+         cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))
+
+         k_ = torch.cat((cls_k, k_), dim=1)
+         v_ = torch.cat((cls_v, v_), dim=1)
+
+         # the same for masking (if provided)
+         if tok_mask is not None:
+             # since mask does not have the latent dim (d), we need to remove it from einops dims
+             mask_ = rearrange(mask_, f'{einops_from} -> {einops_to}'.replace(' d', ''),
+                               **einops_dims)
+             cls_mask = repeat(cls_mask, 'b () -> (b r) ()',
+                               r=r)  # expand cls_mask across time or space
+             mask_ = torch.cat((cls_mask, mask_), dim=1)
+
+         # attention
+         out = qkv_attn(q_, k_, v_, tok_mask=mask_)
+
+         # merge back time or space
+         out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)
+
+         # concat back the cls token
+         out = torch.cat((cls_out, out), dim=1)
+
+         # merge back the heads
+         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+
+         ## to out
+         x = self.proj(out)
+         x = self.proj_drop(x)
+         return x
+
+
+ class DividedSpaceTimeBlock(nn.Module):
+
+     def __init__(self,
+                  dim=768,
+                  num_heads=12,
+                  attn_type='divided',
+                  mlp_ratio=4.,
+                  qkv_bias=False,
+                  drop=0.,
+                  attn_drop=0.,
+                  drop_path=0.,
+                  act_layer=nn.GELU,
+                  norm_layer=nn.LayerNorm):
+         super().__init__()
+
+         self.einops_from_space = 'b (f n) d'
+         self.einops_to_space = '(b f) n d'
+         self.einops_from_time = 'b (f n) d'
+         self.einops_to_time = '(b n) f d'
+
+         self.norm1 = norm_layer(dim)
+
+         self.attn = DividedAttention(dim,
+                                      num_heads=num_heads,
+                                      qkv_bias=qkv_bias,
+                                      attn_drop=attn_drop,
+                                      proj_drop=drop)
+
+         self.timeattn = DividedAttention(dim,
+                                          num_heads=num_heads,
+                                          qkv_bias=qkv_bias,
+                                          attn_drop=attn_drop,
+                                          proj_drop=drop)
+
+         # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+         self.drop_path = nn.Identity()
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = Mlp(in_features=dim,
+                        hidden_features=mlp_hidden_dim,
+                        act_layer=act_layer,
+                        drop=drop)
+         self.norm3 = norm_layer(dim)
+
+     def forward(self,
+                 x,
+                 seq_len=196,
+                 num_frames=8,
+                 approx='none',
+                 num_landmarks=128,
+                 tok_mask: torch.Tensor = None):
+         time_output = self.timeattn(self.norm3(x),
+                                     self.einops_from_time,
+                                     self.einops_to_time,
+                                     n=seq_len,
+                                     tok_mask=tok_mask)
+         time_residual = x + time_output
+
+         space_output = self.attn(self.norm1(time_residual),
+                                  self.einops_from_space,
+                                  self.einops_to_space,
+                                  f=num_frames,
+                                  tok_mask=tok_mask)
+         space_residual = time_residual + self.drop_path(space_output)
+
+         x = space_residual
+         x = x + self.drop_path(self.mlp(self.norm2(x)))
+         return x
+
+
+ class Mlp(nn.Module):
+
+     def __init__(self,
+                  in_features,
+                  hidden_features=None,
+                  out_features=None,
+                  act_layer=nn.GELU,
+                  drop=0.):
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """ Image to Patch Embedding
+     """
+
+     def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
+         super().__init__()
+         img_size = img_size if type(img_size) is tuple else to_2tuple(img_size)
+         patch_size = patch_size if type(patch_size) is tuple else to_2tuple(patch_size)
+         num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.num_patches = num_patches
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = self.proj(x).flatten(2).transpose(1, 2)
+         return x
+
+
+ class PatchEmbed3D(nn.Module):
+     """ Image to Patch Embedding """
+
+     def __init__(self,
+                  img_size=224,
+                  temporal_resolution=4,
+                  in_chans=3,
+                  patch_size=16,
+                  z_block_size=2,
+                  embed_dim=768,
+                  flatten=True):
+         super().__init__()
+         self.height = (img_size // patch_size)
+         self.width = (img_size // patch_size)
+         ### v-iashin: these two are incorrect
+         # self.frames = (temporal_resolution // z_block_size)
+         # self.num_patches = self.height * self.width * self.frames
+         self.z_block_size = z_block_size
+         ###
+         self.proj = nn.Conv3d(in_chans,
+                               embed_dim,
+                               kernel_size=(z_block_size, patch_size, patch_size),
+                               stride=(z_block_size, patch_size, patch_size))
+         self.flatten = flatten
+
+     def forward(self, x):
+         B, C, T, H, W = x.shape
+         x = self.proj(x)
+         if self.flatten:
+             x = x.flatten(2).transpose(1, 2)
+         return x
+
+
+ class HeadMLP(nn.Module):
+
+     def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
+         super(HeadMLP, self).__init__()
+         self.n_input = n_input
+         self.n_classes = n_classes
+         self.n_hidden = n_hidden
+         if n_hidden is None:
+             # use linear classifier
+             self.block_forward = nn.Sequential(nn.Dropout(p=p),
+                                                nn.Linear(n_input, n_classes, bias=True))
+         else:
+             # use simple MLP classifier
+             self.block_forward = nn.Sequential(nn.Dropout(p=p),
+                                                nn.Linear(n_input, n_hidden, bias=True),
+                                                nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True),
+                                                nn.Dropout(p=p),
+                                                nn.Linear(n_hidden, n_classes, bias=True))
+         print(f"Dropout-NLP: {p}")
+
+     def forward(self, x):
+         return self.block_forward(x)
+
+
+ def _conv_filter(state_dict, patch_size=16):
+     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+     out_dict = {}
+     for k, v in state_dict.items():
+         if 'patch_embed.proj.weight' in k:
+             v = v.reshape((v.shape[0], 3, patch_size, patch_size))
+         out_dict[k] = v
+     return out_dict
+
+
+ def adapt_input_conv(in_chans, conv_weight, agg='sum'):
+     conv_type = conv_weight.dtype
+     conv_weight = conv_weight.float()
+     O, I, J, K = conv_weight.shape
+     if in_chans == 1:
+         if I > 3:
+             assert conv_weight.shape[1] % 3 == 0
+             # For models with space2depth stems
+             conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+             conv_weight = conv_weight.sum(dim=2, keepdim=False)
+         else:
+             if agg == 'sum':
+                 print("Summing conv1 weights")
+                 conv_weight = conv_weight.sum(dim=1, keepdim=True)
+             else:
+                 print("Averaging conv1 weights")
+                 conv_weight = conv_weight.mean(dim=1, keepdim=True)
+     elif in_chans != 3:
+         if I != 3:
+             raise NotImplementedError('Weight format not supported by conversion.')
+         else:
+             if agg == 'sum':
+                 print("Summing conv1 weights")
+                 repeat = int(math.ceil(in_chans / 3))
+                 conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+                 conv_weight *= (3 / float(in_chans))
+             else:
+                 print("Averaging conv1 weights")
+                 conv_weight = conv_weight.mean(dim=1, keepdim=True)
+                 conv_weight = conv_weight.repeat(1, in_chans, 1, 1)
+     conv_weight = conv_weight.to(conv_type)
+     return conv_weight
+
+
+ def load_pretrained(model,
+                     cfg=None,
+                     num_classes=1000,
+                     in_chans=3,
+                     filter_fn=None,
+                     strict=True,
+                     progress=False):
+     # Load state dict
+     assert cfg.VIT.PRETRAINED_WEIGHTS in default_cfgs, \
+         f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]"
+     state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS])
+
+     if filter_fn is not None:
+         state_dict = filter_fn(state_dict)
+
+     input_convs = 'patch_embed.proj'
+     if input_convs is not None and in_chans != 3:
+         if isinstance(input_convs, str):
+             input_convs = (input_convs, )
+         for input_conv_name in input_convs:
+             weight_name = input_conv_name + '.weight'
+             try:
+                 state_dict[weight_name] = adapt_input_conv(in_chans,
+                                                            state_dict[weight_name],
+                                                            agg='avg')
+                 print(
+                     f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)'
+                 )
+             except NotImplementedError as e:
+                 del state_dict[weight_name]
+                 strict = False
+                 print(
+                     f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.'
+                 )
+
+     classifier_name = 'head'
+     label_offset = cfg.get('label_offset', 0)
+     pretrain_classes = 1000
+     if num_classes != pretrain_classes:
+         # completely discard fully connected if model num_classes doesn't match pretrained weights
+         del state_dict[classifier_name + '.weight']
+         del state_dict[classifier_name + '.bias']
+         strict = False
+     elif label_offset > 0:
+         # special case for pretrained weights with an extra background class in pretrained weights
+         classifier_weight = state_dict[classifier_name + '.weight']
+         state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
+         classifier_bias = state_dict[classifier_name + '.bias']
+         state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
+
+     loaded_state = state_dict
+     self_state = model.state_dict()
+     all_names = set(self_state.keys())
+     saved_names = set([])
+     for name, param in loaded_state.items():
+         param = param
+         if 'module.' in name:
+             name = name.replace('module.', '')
+         if name in self_state.keys() and param.shape == self_state[name].shape:
+             saved_names.add(name)
+             self_state[name].copy_(param)
+         else:
+             print(f"didn't load: {name} of shape: {param.shape}")
+     print("Missing Keys:")
+     print(all_names - saved_names)
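To make the divided space-time attention above concrete: temporal attention reshapes tokens as 'b (f n) d' -> '(b n) f d', spatial attention as 'b (f n) d' -> '(b f) n d', and the CLS token is split off and handled separately inside DividedAttention. The toy sizes and the import path below are assumptions for illustration, not values used by this repo.

# Toy shape check for DividedSpaceTimeBlock (illustrative sizes, assumed import path).
import torch
from data_utils.ext.synchformer.vit_helper import DividedSpaceTimeBlock

block = DividedSpaceTimeBlock(dim=64, num_heads=4)
B, F, N = 2, 2, 4                       # batch, frames, patches per frame
x = torch.randn(B, 1 + F * N, 64)       # CLS token followed by F*N patch tokens
y = block(x, seq_len=N, num_frames=F)   # time attention -> space attention -> MLP
assert y.shape == x.shape               # the block is shape-preserving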
data_utils/utils.py ADDED
@@ -0,0 +1,115 @@
+ """Utility functions."""
+ import contextlib
+ import csv
+ import json
+ import os
+ import pathlib
+ import warnings
+
+ import numpy as np
+
+
+ def save_args(filename, args):
+     """Save the command-line arguments."""
+     args_dict = {}
+     for key, value in vars(args).items():
+         if isinstance(value, pathlib.Path):
+             args_dict[key] = str(value)
+         else:
+             args_dict[key] = value
+     save_json(filename, args_dict)
+
+
+ def inverse_dict(d):
+     """Return the inverse dictionary."""
+     return {v: k for k, v in d.items()}
+
+
+ def save_txt(filename, data):
+     """Save a list to a TXT file."""
+     with open(filename, "w", encoding="utf8") as f:
+         for item in data:
+             f.write(f"{item}\n")
+
+
+ def load_txt(filename):
+     """Load a TXT file as a list."""
+     with open(filename, encoding="utf8") as f:
+         return [line.strip() for line in f]
+
+
+ def save_json(filename, data):
+     """Save data as a JSON file."""
+     with open(filename, "w", encoding="utf8") as f:
+         json.dump(data, f)
+
+
+ def load_json(filename):
+     """Load data from a JSON file."""
+     with open(filename, encoding="utf8") as f:
+         return json.load(f)
+
+
+ def save_csv(filename, data, header=""):
+     """Save data as a CSV file."""
+     np.savetxt(
+         filename, data, fmt="%d", delimiter=",", header=header, comments=""
+     )
+
+
+ def load_csv(filename, skiprows=1):
+     """Load data from a CSV file."""
+     return np.loadtxt(filename, dtype=int, delimiter=",", skiprows=skiprows)
+
+
+ def load_csv_text(filename, headerless=True):
+     """Read a CSV file into a list of dictionaries or lists."""
+     with open(filename) as f:
+         if headerless:
+             return [row for row in csv.reader(f)]
+         reader = csv.DictReader(f)
+         return [
+             {field: row[field] for field in reader.fieldnames}
+             for row in reader
+         ]
+
+
+ def ignore_exceptions(func):
+     """Decorator that ignores all errors and warnings."""
+
+     def inner(*args, **kwargs):
+         with warnings.catch_warnings():
+             warnings.simplefilter("ignore")
+             try:
+                 return func(*args, **kwargs)
+             except Exception:
+                 return None
+
+     return inner
+
+
+ def suppress_outputs(func):
+     """Decorator that suppresses writing to stdout and stderr."""
+
+     def inner(*args, **kwargs):
+         devnull = open(os.devnull, "w")
+         with contextlib.redirect_stdout(devnull):
+             with contextlib.redirect_stderr(devnull):
+                 return func(*args, **kwargs)
+
+     return inner
+
+
+ def resolve_paths(func):
+     """Decorator that resolves all paths."""
+
+     def inner(*args, **kwargs):
+         parsed = func(*args, **kwargs)
+         for key in vars(parsed).keys():
+             if isinstance(getattr(parsed, key), pathlib.Path):
+                 setattr(
+                     parsed, key, getattr(parsed, key).expanduser().resolve()
+                 )
+         return parsed
+
+     return inner
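The decorators at the end of this module are meant to wrap I/O helpers; a short illustrative example of combining ignore_exceptions with load_json and save_json follows. The file name is a placeholder, not one used by the repo.

# Illustrative use of the helpers above; "example_args.json" is a placeholder name.
@ignore_exceptions
def read_args(path):
    # returns None instead of raising if the file is missing or malformed
    return load_json(path)

args = read_args("example_args.json")
if args is None:
    save_json("example_args.json", {"sample_rate": 44100})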
data_utils/v2a_utils/__init__.py ADDED
File without changes
data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes)
data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc ADDED
Binary file (4.05 kB)
data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc ADDED
Binary file (4.06 kB)
data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc ADDED
Binary file (4.09 kB)
data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc ADDED
Binary file (6.64 kB)
data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc ADDED
Binary file (5.84 kB)
data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc ADDED
Binary file (5.23 kB)
data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc ADDED
Binary file (6.59 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc ADDED
Binary file (5.94 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc ADDED
Binary file (5.95 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc ADDED
Binary file (4.53 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc ADDED
Binary file (4.4 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc ADDED
Binary file (4.49 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc ADDED
Binary file (4.75 kB)
data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc ADDED
Binary file (4.99 kB)
data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc ADDED
Binary file (5.18 kB)
data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc ADDED
Binary file (6.56 kB)
data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc ADDED
Binary file (6.5 kB)
data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc ADDED
Binary file (5.64 kB)
data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc ADDED
Binary file (5.14 kB)
data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc ADDED
Binary file (2.43 kB)
data_utils/v2a_utils/feature_utils_224.py ADDED
@@ -0,0 +1,182 @@
+ from typing import Literal, Optional
+ import json
+ import open_clip
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from open_clip import create_model_from_pretrained
+ from torchvision.transforms import Normalize
+ from think_sound.models.factory import create_model_from_config
+ from think_sound.models.utils import load_ckpt_state_dict
+ from think_sound.training.utils import copy_state_dict
+ from transformers import AutoModel
+ from transformers import AutoProcessor
+ from transformers import T5EncoderModel, AutoTokenizer
+ import logging
+ from data_utils.ext.synchformer import Synchformer
+
+ log = logging.getLogger()
+
+
+ def patch_clip(clip_model):
+     # a hack to make it output last hidden states
+     # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
+     def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None,
+                               output_attentions: Optional[bool] = None,
+                               output_hidden_states: Optional[bool] = None,
+                               return_dict: Optional[bool] = None):
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         last_hidden_state = text_outputs[0]
+         pooled_output = text_outputs[1]
+         text_features = self.text_projection(pooled_output)
+
+         return text_features, last_hidden_state
+
+     clip_model.get_text_features = new_get_text_features.__get__(clip_model)
+     return clip_model
+
+
+ class FeaturesUtils(nn.Module):
+
+     def __init__(
+         self,
+         *,
+         vae_ckpt: Optional[str] = None,
+         vae_config: Optional[str] = None,
+         synchformer_ckpt: Optional[str] = None,
+         enable_conditions: bool = True,
+         need_vae_encoder: bool = True,
+     ):
+         super().__init__()
+
+         if enable_conditions:
+             self.clip_model = AutoModel.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
+             self.clip_model = patch_clip(self.clip_model)
+             self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
+             self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl")
+             self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
+             # self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
+             #                                  std=[0.26862954, 0.26130258, 0.27577711])
+             self.synchformer = Synchformer()
+             self.synchformer.load_state_dict(
+                 torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))
+
+             # self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')  # same as 'ViT-H-14'
+         else:
+             self.clip_model = None
+             self.synchformer = None
+             self.tokenizer = None
+
+         if vae_ckpt is not None:
+             with open(vae_config) as f:
+                 vae_config = json.load(f)
+             self.vae = create_model_from_config(vae_config)
+             print(f"Loading model checkpoint from {vae_ckpt}")
+             # Load checkpoint
+             copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.'))  # ,prefix='autoencoder.'
+         else:
+             self.tod = None
+
+     def compile(self):
+         if self.clip_model is not None:
+             self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
+             self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
+         if self.synchformer is not None:
+             self.synchformer = torch.compile(self.synchformer)
+
+     def train(self, mode: bool) -> None:
+         return super().train(False)
+
+     @torch.inference_mode()
+     def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
+         assert self.clip_model is not None, 'CLIP is not loaded'
+         # x: (B, T, C, H, W) H/W: 224
+         b, t, c, h, w = x.shape
+
+         assert c == 3 and h == 224 and w == 224
+         # x = self.clip_preprocess(x)
+         x = rearrange(x, 'b t c h w -> (b t) c h w')
+         outputs = []
+         if batch_size < 0:
+             batch_size = b * t
+         for i in range(0, b * t, batch_size):
+             outputs.append(self.clip_model.get_image_features(x[i:i + batch_size]))
+         x = torch.cat(outputs, dim=0)
+         # x = self.clip_model.encode_image(x, normalize=True)
+         x = rearrange(x, '(b t) d -> b t d', b=b)
+         return x
+
+     @torch.inference_mode()
+     def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
+         assert self.synchformer is not None, 'Synchformer is not loaded'
+         # x: (B, T, C, H, W) H/W: 224
+         b, t, c, h, w = x.shape
+         # import ipdb
+         # ipdb.set_trace()
+         assert c == 3 and h == 224 and w == 224
+
+         # partition the video
+         segment_size = 16
+         step_size = 8
+         num_segments = (t - segment_size) // step_size + 1
+         segments = []
+         for i in range(num_segments):
+             segments.append(x[:, i * step_size:i * step_size + segment_size])
+         x = torch.stack(segments, dim=1)  # (B, S, T, C, H, W)
+
+         outputs = []
+         if batch_size < 0:
+             batch_size = b
+         x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
+         for i in range(0, b * num_segments, batch_size):
+             outputs.append(self.synchformer(x[i:i + batch_size]))
+         x = torch.cat(outputs, dim=0)
+         x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
+         return x
+
+     @torch.inference_mode()
+     def encode_text(self, text: list[str]) -> torch.Tensor:
+         assert self.clip_model is not None, 'CLIP is not loaded'
+         # assert self.tokenizer is not None, 'Tokenizer is not loaded'
+         # x: (B, L)
+         tokens = self.clip_processor(text=text, truncation=True, max_length=77,
+                                      padding="max_length", return_tensors="pt").to(self.device)
+         return self.clip_model.get_text_features(**tokens)
+
+     @torch.inference_mode()
+     def encode_t5_text(self, text: list[str]) -> torch.Tensor:
+         assert self.t5_model is not None, 'T5 model is not loaded'
+         assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded'
+         # x: (B, L)
+         inputs = self.t5_tokenizer(text,
+                                    truncation=True,
+                                    max_length=77,
+                                    padding="max_length",
+                                    return_tensors="pt").to(self.device)
+         return self.t5_model(**inputs).last_hidden_state
+
+     @torch.inference_mode()
+     def encode_audio(self, x) -> torch.Tensor:
+         x = self.vae.encode(x)
+         return x
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+     @property
+     def dtype(self):
+         return next(self.parameters()).dtype
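The segmentation in encode_video_with_sync requires the clip to cover at least one 16-frame window with a stride of 8, i.e. num_segments = (t - 16) // 8 + 1. A hedged driver sketch follows; the checkpoint path and frame count are placeholders, and the constructor will download the pretrained MetaCLIP and T5 weights.

# Hypothetical driver (checkpoint path and sizes are placeholders, not from this repo's configs).
import torch

if __name__ == "__main__":
    feats = FeaturesUtils(synchformer_ckpt="ckpts/synchformer_state_dict.pth",
                          enable_conditions=True).eval()
    frames = torch.randn(1, 24, 3, 224, 224)          # 24 frames -> (24 - 16) // 8 + 1 = 2 segments
    sync_feat = feats.encode_video_with_sync(frames)  # (1, num_segments * tokens_per_segment, d)
    text_feat = feats.encode_text(["a dog barking"])  # pooled CLIP text features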