Spaces: Running on Zero
Commit 08f69f6
1 Parent(s): a1f4877
init
This view is limited to 50 files because the commit contains too many changes.
- README.md +4 -5
- app.py +331 -0
- data_utils/__init__.py +0 -0
- data_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- data_utils/__pycache__/utils.cpython-310.pyc +0 -0
- data_utils/__pycache__/utils.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/LICENSE +21 -0
- data_utils/ext/synchformer/__init__.py +1 -0
- data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc +0 -0
- data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc +0 -0
- data_utils/ext/synchformer/divided_224_16x4.yaml +84 -0
- data_utils/ext/synchformer/motionformer.py +400 -0
- data_utils/ext/synchformer/synchformer.py +55 -0
- data_utils/ext/synchformer/utils.py +92 -0
- data_utils/ext/synchformer/video_model_builder.py +277 -0
- data_utils/ext/synchformer/vit_helper.py +399 -0
- data_utils/utils.py +115 -0
- data_utils/v2a_utils/__init__.py +0 -0
- data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc +0 -0
- data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc +0 -0
- data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc +0 -0
- data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc +0 -0
- data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc +0 -0
- data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc +0 -0
- data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc +0 -0
- data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc +0 -0
- data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc +0 -0
- data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc +0 -0
- data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc +0 -0
- data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc +0 -0
- data_utils/v2a_utils/feature_utils_224.py +182 -0
README.md
CHANGED
@@ -1,14 +1,13 @@
 ---
-title:
+title: Test
-emoji:
+emoji: 📚
-colorFrom:
+colorFrom: gray
 colorTo: gray
 sdk: gradio
 sdk_version: 5.35.0
 app_file: app.py
 pinned: false
-license:
+license: mit
-short_description: 'demo of ThinkSound '
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,331 @@
from prefigure.prefigure import get_all_args, push_wandb_config
import json
import os
os.environ["GRADIO_TEMP_DIR"] = "./.gradio_tmp"
import re
import torch
import torchaudio
# import pytorch_lightning as pl
import lightning as L
from lightning.pytorch.callbacks import Timer, ModelCheckpoint, BasePredictionWriter
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.tuner import Tuner
from lightning.pytorch import seed_everything
import random
from datetime import datetime
# from think_sound.data.dataset import create_dataloader_from_config
from think_sound.data.datamodule import DataModule
from think_sound.models import create_model_from_config
from think_sound.models.utils import load_ckpt_state_dict, remove_weight_norm_from_model
from think_sound.training import create_training_wrapper_from_config, create_demo_callback_from_config
from think_sound.training.utils import copy_state_dict
from think_sound.inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
from data_utils.v2a_utils.feature_utils_224 import FeaturesUtils
from torch.utils.data import Dataset
from typing import Optional, Union
from torchvision.transforms import v2
from torio.io import StreamingMediaDecoder
from torchvision.utils import save_image
from transformers import AutoProcessor
import torch.nn.functional as F
import gradio as gr
import tempfile
import subprocess
from huggingface_hub import hf_hub_download

_CLIP_SIZE = 224
_CLIP_FPS = 8.0

_SYNC_SIZE = 224
_SYNC_FPS = 25.0


def pad_to_square(video_tensor):
    if len(video_tensor.shape) != 4:
        raise ValueError("Input tensor must have shape (l, c, h, w)")

    l, c, h, w = video_tensor.shape
    max_side = max(h, w)

    pad_h = max_side - h
    pad_w = max_side - w

    padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)

    video_padded = F.pad(video_tensor, pad=padding, mode='constant', value=0)

    return video_padded


class VGGSound(Dataset):

    def __init__(
        self,
        sample_rate: int = 44_100,
        duration_sec: float = 9.0,
        audio_samples: Optional[int] = 397312,
        normalize_audio: bool = False,
    ):
        if audio_samples is None:
            self.audio_samples = int(sample_rate * duration_sec)
        else:
            self.audio_samples = audio_samples
            effective_duration = audio_samples / sample_rate
            # make sure the duration is close enough, within 15ms
            assert abs(effective_duration - duration_sec) < 0.015, \
                f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'

        self.sample_rate = sample_rate
        self.duration_sec = duration_sec

        self.expected_audio_length = self.audio_samples
        self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
        self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)

        self.clip_transform = v2.Compose([
            v2.Lambda(pad_to_square),  # pad to a square first
            v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ])
        self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
        self.sync_transform = v2.Compose([
            v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
            v2.CenterCrop(_SYNC_SIZE),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

        self.resampler = {}

    def sample(self, video_path, label):
        video_id = video_path

        reader = StreamingMediaDecoder(video_path)
        reader.add_basic_video_stream(
            frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
            frame_rate=_CLIP_FPS,
            format='rgb24',
        )
        reader.add_basic_video_stream(
            frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
            frame_rate=_SYNC_FPS,
            format='rgb24',
        )

        reader.fill_buffer()
        data_chunk = reader.pop_chunks()

        clip_chunk = data_chunk[0]
        sync_chunk = data_chunk[1]

        if sync_chunk is None:
            raise RuntimeError(f'Sync video returned None {video_id}')

        clip_chunk = clip_chunk[:self.clip_expected_length]
        # import ipdb
        # ipdb.set_trace()
        if clip_chunk.shape[0] != self.clip_expected_length:
            current_length = clip_chunk.shape[0]
            padding_needed = self.clip_expected_length - current_length

            # Check that padding needed is no more than 2
            assert padding_needed < 4, f'Padding no more than 2 frames allowed, but {padding_needed} needed'

            # If assertion passes, proceed with padding
            if padding_needed > 0:
                last_frame = clip_chunk[-1]
                print(last_frame.shape)
                # Repeat the last frame to reach the expected length
                padding = last_frame.repeat(padding_needed, 1, 1, 1)
                clip_chunk = torch.cat((clip_chunk, padding), dim=0)
            # raise RuntimeError(f'CLIP video wrong length {video_id}, '
            #                    f'expected {self.clip_expected_length}, '
            #                    f'got {clip_chunk.shape[0]}')

        # save_image(clip_chunk[0] / 255.0, 'ori.png')
        clip_chunk = pad_to_square(clip_chunk)

        clip_chunk = self.clip_processor(images=clip_chunk, return_tensors="pt")["pixel_values"]

        sync_chunk = sync_chunk[:self.sync_expected_length]
        if sync_chunk.shape[0] != self.sync_expected_length:
            # padding using the last frame, but no more than 2
            current_length = sync_chunk.shape[0]
            last_frame = sync_chunk[-1]
            # repeat the last frame to pad up to the expected length
            padding = last_frame.repeat(self.sync_expected_length - current_length, 1, 1, 1)
            assert self.sync_expected_length - current_length < 12, f'sync can pad no more than 2 while {self.sync_expected_length - current_length}'
            sync_chunk = torch.cat((sync_chunk, padding), dim=0)
            # raise RuntimeError(f'Sync video wrong length {video_id}, '
            #                    f'expected {self.sync_expected_length}, '
            #                    f'got {sync_chunk.shape[0]}')

        sync_chunk = self.sync_transform(sync_chunk)
        # assert audio_chunk.shape[1] == self.expected_audio_length and clip_chunk.shape[0] == self.clip_expected_length \
        #     and sync_chunk.shape[0] == self.sync_expected_length, 'error processed data shape'
        data = {
            'id': video_id,
            'caption': label,
            # 'audio': audio_chunk,
            'clip_video': clip_chunk,
            'sync_video': sync_chunk,
        }

        return data


# check which devices are available
if torch.cuda.is_available():
    device = 'cuda'
    extra_device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
    extra_device = 'cpu'

vae_ckpt = hf_hub_download(repo_id="UncleWang233/occdata", filename="epoch=3-step=100000.ckpt", repo_type="dataset")
synchformer_ckpt = hf_hub_download(repo_id="UncleWang233/occdata", filename="synchformer_state_dict.pth", repo_type="dataset")
feature_extractor = FeaturesUtils(
    vae_ckpt=vae_ckpt,
    vae_config='think_sound/configs/model_configs/autoencoders/stable_audio_2_0_vae.json',
    enable_conditions=True,
    synchformer_ckpt=synchformer_ckpt
).eval().to(extra_device)

preprocesser = VGGSound()

args = get_all_args()

seed = 10086

seed_everything(seed, workers=True)


# Get JSON config from args.model_config
with open("think_sound/configs/model_configs/vt2audio/latent_clip_224_text_sync_mmdit_flow_logit_t5_kernel_size3.json") as f:
    model_config = json.load(f)

model = create_model_from_config(model_config)

## speed by torch.compile
if args.compile:
    model = torch.compile(model)

if args.pretrained_ckpt_path:
    copy_state_dict(model, load_ckpt_state_dict(args.pretrained_ckpt_path, prefix='diffusion.'))  # autoencoder. diffusion.

if args.remove_pretransform_weight_norm == "pre_load":
    remove_weight_norm_from_model(model.pretransform)


load_vae_state = load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.')
# new_state_dict = {k.replace("autoencoder.", ""): v for k, v in load_vae_state.items() if k.startswith("autoencoder.")}
model.pretransform.load_state_dict(load_vae_state)

# Remove weight_norm from the pretransform if specified
if args.remove_pretransform_weight_norm == "post_load":
    remove_weight_norm_from_model(model.pretransform)
ckpt_path = hf_hub_download(repo_id="UncleWang233/occdata", filename="epoch=10-step=68000.ckpt", repo_type="dataset")
training_wrapper = create_training_wrapper_from_config(model_config, model)
# choose map_location based on the available device when loading the model weights
if device == 'cuda':
    training_wrapper.load_state_dict(torch.load(ckpt_path)['state_dict'])
else:
    training_wrapper.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))['state_dict'])


def get_audio(video_path, caption):
    # allow the caption to be empty
    if caption is None:
        caption = ''
    timer = Timer(duration="00:15:00:00")
    data = preprocesser.sample(video_path, caption)

    preprocessed_data = {}
    metaclip_global_text_features, metaclip_text_features = feature_extractor.encode_text(data['caption'])
    preprocessed_data['metaclip_global_text_features'] = metaclip_global_text_features.detach().cpu().squeeze(0)
    preprocessed_data['metaclip_text_features'] = metaclip_text_features.detach().cpu().squeeze(0)

    t5_features = feature_extractor.encode_t5_text(data['caption'])
    preprocessed_data['t5_features'] = t5_features.detach().cpu().squeeze(0)

    clip_features = feature_extractor.encode_video_with_clip(data['clip_video'].unsqueeze(0).to(extra_device))
    preprocessed_data['metaclip_features'] = clip_features.detach().cpu().squeeze(0)

    sync_features = feature_extractor.encode_video_with_sync(data['sync_video'].unsqueeze(0).to(extra_device))
    preprocessed_data['sync_features'] = sync_features.detach().cpu().squeeze(0)
    preprocessed_data['video_exist'] = torch.tensor(True)

    metadata = [preprocessed_data]

    batch_size = 1
    length = 194
    with torch.amp.autocast(device):
        conditioning = training_wrapper.diffusion.conditioner(metadata, training_wrapper.device)

    video_exist = torch.stack([item['video_exist'] for item in metadata], dim=0)
    conditioning['metaclip_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_clip_feat
    conditioning['sync_features'][~video_exist] = training_wrapper.diffusion.model.model.empty_sync_feat

    cond_inputs = training_wrapper.diffusion.get_conditioning_inputs(conditioning)
    noise = torch.randn([batch_size, training_wrapper.diffusion.io_channels, length]).to(training_wrapper.device)
    with torch.amp.autocast(device):
        model = training_wrapper.diffusion.model
        if training_wrapper.diffusion_objective == "v":
            fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
        elif training_wrapper.diffusion_objective == "rectified_flow":
            import time
            start_time = time.time()
            fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
            end_time = time.time()
            execution_time = end_time - start_time
            print(f"Execution time: {execution_time:.2f} s")
        if training_wrapper.diffusion.pretransform is not None:
            fakes = training_wrapper.diffusion.pretransform.decode(fakes)

    audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
    # save a temporary audio file
    with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_audio:
        torchaudio.save(tmp_audio.name, audios[0], 44100)
        audio_path = tmp_audio.name
    return audio_path


# Synthesize a new video: mux the generated audio with the original video using ffmpeg

def synthesize_video_with_audio(video_file, caption):
    # allow the caption to be empty
    if caption is None:
        caption = ''
    audio_path = get_audio(video_file, caption)
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_video:
        output_video_path = tmp_video.name
    # ffmpeg command: replace the original audio track with the new audio
    cmd = [
        'ffmpeg', '-y', '-i', video_file, '-i', audio_path,
        '-c:v', 'copy', '-map', '0:v:0', '-map', '1:a:0',
        '-shortest', output_video_path
    ]
    subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return output_video_path


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# ThinkSound\nupload video and caption(optional), and get video with audio!")
    with gr.Row():
        video_input = gr.Video(label="upload video")
        caption_input = gr.Textbox(label="caption(optional)", placeholder="can be empty", lines=1)
    output_video = gr.Video(label="output video")
    btn = gr.Button("start synthesize")
    btn.click(fn=synthesize_video_with_audio, inputs=[video_input, caption_input], outputs=output_video)

    gr.Examples(
        examples=[
            ["./examples/1_mute.mp4", "Playing Trumpet"],
            ["./examples/2_mute.mp4", "Axe striking"],
            ["./examples/3_mute.mp4", "Gentle Sucking Sounds From the Pacifier"],
            ["./examples/4_mute.mp4", "train passing by"],
            ["./examples/5_mute.mp4", "Lighting Firecrackers"]
        ],
        inputs=[video_input, caption_input],
    )

demo.launch(share=True)
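Editor's note: the final muxing step that synthesize_video_with_audio performs can be reproduced on its own. A minimal sketch, assuming ffmpeg is installed and using placeholder paths (not files shipped in this commit):

# Standalone illustration of the muxing step used by synthesize_video_with_audio above:
# copy the video stream unchanged and replace the audio track with a generated .wav.
import subprocess

video_file = "./examples/1_mute.mp4"      # silent input clip (placeholder path)
audio_path = "generated.wav"              # audio produced by get_audio() (placeholder path)
output_video_path = "output_with_audio.mp4"

cmd = [
    "ffmpeg", "-y", "-i", video_file, "-i", audio_path,
    "-c:v", "copy", "-map", "0:v:0", "-map", "1:a:0",
    "-shortest", output_video_path,
]
subprocess.run(cmd, check=True)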
data_utils/__init__.py
ADDED
File without changes
data_utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (149 Bytes).
data_utils/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (4.56 kB).
data_utils/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (4.56 kB).
data_utils/ext/synchformer/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Vladimir Iashin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
data_utils/ext/synchformer/__init__.py
ADDED
@@ -0,0 +1 @@
from data_utils.ext.synchformer.synchformer import Synchformer
data_utils/ext/synchformer/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (243 Bytes).
data_utils/ext/synchformer/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (241 Bytes).
data_utils/ext/synchformer/__pycache__/motionformer.cpython-310.pyc
ADDED
Binary file (12.7 kB).
data_utils/ext/synchformer/__pycache__/motionformer.cpython-39.pyc
ADDED
Binary file (12.7 kB).
data_utils/ext/synchformer/__pycache__/synchformer.cpython-310.pyc
ADDED
Binary file (1.91 kB).
data_utils/ext/synchformer/__pycache__/synchformer.cpython-39.pyc
ADDED
Binary file (1.9 kB).
data_utils/ext/synchformer/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (3.97 kB).
data_utils/ext/synchformer/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.78 kB).
data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-310.pyc
ADDED
Binary file (5.84 kB).
data_utils/ext/synchformer/__pycache__/video_model_builder.cpython-39.pyc
ADDED
Binary file (5.8 kB).
data_utils/ext/synchformer/__pycache__/vit_helper.cpython-310.pyc
ADDED
Binary file (10.6 kB).
data_utils/ext/synchformer/__pycache__/vit_helper.cpython-39.pyc
ADDED
Binary file (10.6 kB).
data_utils/ext/synchformer/divided_224_16x4.yaml
ADDED
@@ -0,0 +1,84 @@
TRAIN:
  ENABLE: True
  DATASET: Ssv2
  BATCH_SIZE: 32
  EVAL_PERIOD: 5
  CHECKPOINT_PERIOD: 5
  AUTO_RESUME: True
  CHECKPOINT_EPOCH_RESET: True
  CHECKPOINT_FILE_PATH: /checkpoint/fmetze/neurips_sota/40944587/checkpoints/checkpoint_epoch_00035.pyth
DATA:
  NUM_FRAMES: 16
  SAMPLING_RATE: 4
  TRAIN_JITTER_SCALES: [256, 320]
  TRAIN_CROP_SIZE: 224
  TEST_CROP_SIZE: 224
  INPUT_CHANNEL_NUM: [3]
  MEAN: [0.5, 0.5, 0.5]
  STD: [0.5, 0.5, 0.5]
  PATH_TO_DATA_DIR: /private/home/mandelapatrick/slowfast/data/ssv2
  PATH_PREFIX: /datasets01/SomethingV2/092720/20bn-something-something-v2-frames
  INV_UNIFORM_SAMPLE: True
  RANDOM_FLIP: False
  REVERSE_INPUT_CHANNEL: True
  USE_RAND_AUGMENT: True
  RE_PROB: 0.0
  USE_REPEATED_AUG: False
  USE_RANDOM_RESIZE_CROPS: False
  COLORJITTER: False
  GRAYSCALE: False
  GAUSSIAN: False
SOLVER:
  BASE_LR: 1e-4
  LR_POLICY: steps_with_relative_lrs
  LRS: [1, 0.1, 0.01]
  STEPS: [0, 20, 30]
  MAX_EPOCH: 35
  MOMENTUM: 0.9
  WEIGHT_DECAY: 5e-2
  WARMUP_EPOCHS: 0.0
  OPTIMIZING_METHOD: adamw
  USE_MIXED_PRECISION: True
  SMOOTHING: 0.2
SLOWFAST:
  ALPHA: 8
VIT:
  PATCH_SIZE: 16
  PATCH_SIZE_TEMP: 2
  CHANNELS: 3
  EMBED_DIM: 768
  DEPTH: 12
  NUM_HEADS: 12
  MLP_RATIO: 4
  QKV_BIAS: True
  VIDEO_INPUT: True
  TEMPORAL_RESOLUTION: 8
  USE_MLP: True
  DROP: 0.0
  POS_DROPOUT: 0.0
  DROP_PATH: 0.2
  IM_PRETRAINED: True
  HEAD_DROPOUT: 0.0
  HEAD_ACT: tanh
  PRETRAINED_WEIGHTS: vit_1k
  ATTN_LAYER: divided
MODEL:
  NUM_CLASSES: 174
  ARCH: slow
  MODEL_NAME: VisionTransformer
  LOSS_FUNC: cross_entropy
TEST:
  ENABLE: True
  DATASET: Ssv2
  BATCH_SIZE: 64
  NUM_ENSEMBLE_VIEWS: 1
  NUM_SPATIAL_CROPS: 3
DATA_LOADER:
  NUM_WORKERS: 4
  PIN_MEMORY: True
NUM_GPUS: 8
NUM_SHARDS: 4
RNG_SEED: 0
OUTPUT_DIR: .
TENSORBOARD:
  ENABLE: True
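Editor's note: motionformer.py (below) loads this config with OmegaConf and then patches a few VIT defaults that are not present in the file. A minimal sketch of that pattern, assuming it is run from the repository root:

# Sketch: read the Motionformer config the same way motionformer.py does.
from omegaconf import OmegaConf

cfg = OmegaConf.load("data_utils/ext/synchformer/divided_224_16x4.yaml")  # path assumed relative to repo root
print(cfg.VIT.EMBED_DIM, cfg.VIT.PATCH_SIZE, cfg.DATA.NUM_FRAMES)  # 768, 16, 16

# motionformer.py additionally injects defaults missing from this file, e.g.:
cfg.VIT.ATTN_DROPOUT = 0.0
cfg.VIT.POS_EMBED = "separate"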
data_utils/ext/synchformer/motionformer.py
ADDED
@@ -0,0 +1,400 @@
import logging
from pathlib import Path

import einops
import torch
from omegaconf import OmegaConf
from timm.layers import trunc_normal_
from torch import nn

from data_utils.ext.synchformer.utils import check_if_file_exists_else_download
from data_utils.ext.synchformer.video_model_builder import VisionTransformer

FILE2URL = {
    # cfg
    'motionformer_224_16x4.yaml':
    'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/motionformer_224_16x4.yaml',
    'joint_224_16x4.yaml':
    'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/joint_224_16x4.yaml',
    'divided_224_16x4.yaml':
    'https://raw.githubusercontent.com/facebookresearch/Motionformer/bf43d50/configs/SSV2/divided_224_16x4.yaml',
    # ckpt
    'ssv2_motionformer_224_16x4.pyth':
    'https://dl.fbaipublicfiles.com/motionformer/ssv2_motionformer_224_16x4.pyth',
    'ssv2_joint_224_16x4.pyth':
    'https://dl.fbaipublicfiles.com/motionformer/ssv2_joint_224_16x4.pyth',
    'ssv2_divided_224_16x4.pyth':
    'https://dl.fbaipublicfiles.com/motionformer/ssv2_divided_224_16x4.pyth',
}


class MotionFormer(VisionTransformer):
    ''' This class serves three purposes:
        1. Renames the class to MotionFormer.
        2. Downloads the cfg from the original repo and patches it if needed.
        3. Takes care of feature extraction by redefining .forward()
            - if `extract_features=True` and `factorize_space_time=False`,
              the output is of shape (B, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
            - if `extract_features=True` and `factorize_space_time=True`, the output is of shape (B*S, D)
              and spatial and temporal transformer encoder layers are used.
            - if `extract_features=True` and `factorize_space_time=True` as well as `add_global_repr=True`
              the output is of shape (B, D) and spatial and temporal transformer encoder layers
              are used as well as the global representation is extracted from segments (extra pos emb
              is added).
    '''

    def __init__(
        self,
        extract_features: bool = False,
        ckpt_path: str = None,
        factorize_space_time: bool = None,
        agg_space_module: str = None,
        agg_time_module: str = None,
        add_global_repr: bool = True,
        agg_segments_module: str = None,
        max_segments: int = None,
    ):
        self.extract_features = extract_features
        self.ckpt_path = ckpt_path
        self.factorize_space_time = factorize_space_time

        if self.ckpt_path is not None:
            check_if_file_exists_else_download(self.ckpt_path, FILE2URL)
            ckpt = torch.load(self.ckpt_path, map_location='cpu')
            mformer_ckpt2cfg = {
                'ssv2_motionformer_224_16x4.pyth': 'motionformer_224_16x4.yaml',
                'ssv2_joint_224_16x4.pyth': 'joint_224_16x4.yaml',
                'ssv2_divided_224_16x4.pyth': 'divided_224_16x4.yaml',
            }
            # init from motionformer ckpt or from our Stage I ckpt
            # depending on whether the feat extractor was pre-trained on AVCLIPMoCo or not, we need to
            # load the state dict differently
            was_pt_on_avclip = self.ckpt_path.endswith('.pt')  # checks if it is a stage I ckpt (FIXME: a bit generic)
            if self.ckpt_path.endswith(tuple(mformer_ckpt2cfg.keys())):
                cfg_fname = mformer_ckpt2cfg[Path(self.ckpt_path).name]
            elif was_pt_on_avclip:
                # TODO: this is a hack, we should be able to get the cfg from the ckpt (earlier ckpt didn't have it)
                s1_cfg = ckpt.get('args', None)  # Stage I cfg
                if s1_cfg is not None:
                    s1_vfeat_extractor_ckpt_path = s1_cfg.model.params.vfeat_extractor.params.ckpt_path
                    # if the stage I ckpt was initialized from a motionformer ckpt or trained from scratch
                    if s1_vfeat_extractor_ckpt_path is not None:
                        cfg_fname = mformer_ckpt2cfg[Path(s1_vfeat_extractor_ckpt_path).name]
                    else:
                        cfg_fname = 'divided_224_16x4.yaml'
                else:
                    cfg_fname = 'divided_224_16x4.yaml'
            else:
                raise ValueError(f'ckpt_path {self.ckpt_path} is not supported.')
        else:
            was_pt_on_avclip = False
            cfg_fname = 'divided_224_16x4.yaml'
            # logging.info(f'No ckpt_path provided, using {cfg_fname} config.')

        if cfg_fname in ['motionformer_224_16x4.yaml', 'divided_224_16x4.yaml']:
            pos_emb_type = 'separate'
        elif cfg_fname == 'joint_224_16x4.yaml':
            pos_emb_type = 'joint'

        self.mformer_cfg_path = Path(__file__).absolute().parent / cfg_fname

        check_if_file_exists_else_download(self.mformer_cfg_path, FILE2URL)
        mformer_cfg = OmegaConf.load(self.mformer_cfg_path)
        logging.info(f'Loading MotionFormer config from {self.mformer_cfg_path.absolute()}')

        # patch the cfg (from the default cfg defined in the repo `Motionformer/slowfast/config/defaults.py`)
        mformer_cfg.VIT.ATTN_DROPOUT = 0.0
        mformer_cfg.VIT.POS_EMBED = pos_emb_type
        mformer_cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE = True
        mformer_cfg.VIT.APPROX_ATTN_TYPE = 'none'  # guessing
        mformer_cfg.VIT.APPROX_ATTN_DIM = 64  # from ckpt['cfg']

        # finally init VisionTransformer with the cfg
        super().__init__(mformer_cfg)

        # load the ckpt now if ckpt is provided and not from AVCLIPMoCo-pretrained ckpt
        if (self.ckpt_path is not None) and (not was_pt_on_avclip):
            _ckpt_load_status = self.load_state_dict(ckpt['model_state'], strict=False)
            if len(_ckpt_load_status.missing_keys) > 0 or len(_ckpt_load_status.unexpected_keys) > 0:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed.' \
                                f'Missing keys: {_ckpt_load_status.missing_keys}, ' \
                                f'Unexpected keys: {_ckpt_load_status.unexpected_keys}')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        if self.extract_features:
            assert isinstance(self.norm,
                              nn.LayerNorm), 'early x[:, 1:, :] may not be safe for per-tr weights'
            # pre-logits are Sequential(nn.Linear(emb, emd), act) and `act` is tanh but see the logger
            self.pre_logits = nn.Identity()
            # we don't need the classification head (saving memory)
            self.head = nn.Identity()
            self.head_drop = nn.Identity()
            # avoiding code duplication (used only if agg_*_module is TransformerEncoderLayer)
            transf_enc_layer_kwargs = dict(
                d_model=self.embed_dim,
                nhead=self.num_heads,
                activation=nn.GELU(),
                batch_first=True,
                dim_feedforward=self.mlp_ratio * self.embed_dim,
                dropout=self.drop_rate,
                layer_norm_eps=1e-6,
                norm_first=True,
            )
            # define adapters if needed
            if self.factorize_space_time:
                if agg_space_module == 'TransformerEncoderLayer':
                    self.spatial_attn_agg = SpatialTransformerEncoderLayer(
                        **transf_enc_layer_kwargs)
                elif agg_space_module == 'AveragePooling':
                    self.spatial_attn_agg = AveragePooling(avg_pattern='BS D t h w -> BS D t',
                                                           then_permute_pattern='BS D t -> BS t D')
                if agg_time_module == 'TransformerEncoderLayer':
                    self.temp_attn_agg = TemporalTransformerEncoderLayer(**transf_enc_layer_kwargs)
                elif agg_time_module == 'AveragePooling':
                    self.temp_attn_agg = AveragePooling(avg_pattern='BS t D -> BS D')
                elif 'Identity' in agg_time_module:
                    self.temp_attn_agg = nn.Identity()
            # define a global aggregation layer (aggregate over segments)
            self.add_global_repr = add_global_repr
            if add_global_repr:
                if agg_segments_module == 'TransformerEncoderLayer':
                    # we can reuse the same layer as for temporal factorization (B, dim_to_agg, D) -> (B, D)
                    # we need to add pos emb (PE) because previously we added the same PE for each segment
                    pos_max_len = max_segments if max_segments is not None else 16  # 16 = 10sec//0.64sec + 1
                    self.global_attn_agg = TemporalTransformerEncoderLayer(
                        add_pos_emb=True,
                        pos_emb_drop=mformer_cfg.VIT.POS_DROPOUT,
                        pos_max_len=pos_max_len,
                        **transf_enc_layer_kwargs)
                elif agg_segments_module == 'AveragePooling':
                    self.global_attn_agg = AveragePooling(avg_pattern='B S D -> B D')

        if was_pt_on_avclip:
            # we need to filter out the state_dict of the AVCLIP model (has both A and V extractors)
            # and keep only the state_dict of the feat extractor
            ckpt_weights = dict()
            for k, v in ckpt['state_dict'].items():
                if k.startswith(('module.v_encoder.', 'v_encoder.')):
                    k = k.replace('module.', '').replace('v_encoder.', '')
                    ckpt_weights[k] = v
            _load_status = self.load_state_dict(ckpt_weights, strict=False)
            if len(_load_status.missing_keys) > 0 or len(_load_status.unexpected_keys) > 0:
                logging.warning(f'Loading exact vfeat_extractor ckpt from {self.ckpt_path} failed. \n' \
                                f'Missing keys ({len(_load_status.missing_keys)}): ' \
                                f'{_load_status.missing_keys}, \n' \
                                f'Unexpected keys ({len(_load_status.unexpected_keys)}): ' \
                                f'{_load_status.unexpected_keys} \n' \
                                f'temp_attn_agg are expected to be missing if ckpt was pt contrastively.')
            else:
                logging.info(f'Loading vfeat_extractor ckpt from {self.ckpt_path} succeeded.')

        # patch_embed is not used in MotionFormer, only patch_embed_3d, because cfg.VIT.PATCH_SIZE_TEMP > 1
        # but it is used to calculate the number of patches, so we need to keep it
        self.patch_embed.requires_grad_(False)

    def forward(self, x):
        '''
        x is of shape (B, S, C, T, H, W) where S is the number of segments.
        '''
        # Batch, Segments, Channels, T=frames, Height, Width
        B, S, C, T, H, W = x.shape
        # Motionformer expects a tensor of shape (1, B, C, T, H, W).
        # The first dimension (1) is a dummy dimension to make the input tensor and won't be used:
        # see `video_model_builder.video_input`.
        # x = x.unsqueeze(0)  # (1, B, S, C, T, H, W)

        orig_shape = (B, S, C, T, H, W)
        x = x.view(B * S, C, T, H, W)  # flatten batch and segments
        x = self.forward_segments(x, orig_shape=orig_shape)
        # unpack the segments (using rest dimensions to support different shapes e.g. (BS, D) or (BS, t, D))
        x = x.view(B, S, *x.shape[1:])
        # x is now of shape (B*S, D) or (B*S, t, D) if `self.temp_attn_agg` is `Identity`

        return x  # x is (B, S, ...)

    def forward_segments(self, x, orig_shape: tuple) -> torch.Tensor:
        '''x is of shape (1, BS, C, T, H, W) where S is the number of segments.'''
        x, x_mask = self.forward_features(x)

        assert self.extract_features

        # (BS, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
        x = x[:, 1:, :]  # without the CLS token for efficiency (should be safe for LayerNorm and FC)
        x = self.norm(x)
        x = self.pre_logits(x)
        if self.factorize_space_time:
            x = self.restore_spatio_temp_dims(x, orig_shape)  # (B*S, D, t, h, w) <- (B*S, t*h*w, D)

            x = self.spatial_attn_agg(x, x_mask)  # (B*S, t, D)
            x = self.temp_attn_agg(x)  # (B*S, D) or (BS, t, D) if `self.temp_attn_agg` is `Identity`

        return x

    def restore_spatio_temp_dims(self, feats: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
        '''
        feats are of shape (B*S, T, D) where T = 1 + (224 // 16) * (224 // 16) * 8
        Our goal is to make them of shape (B*S, t, h, w, D) where h, w are the spatial dimensions.
        From `self.patch_embed_3d`, it follows that we could reshape feats with:
            `feats.transpose(1, 2).view(B*S, D, t, h, w)`
        '''
        B, S, C, T, H, W = orig_shape
        D = self.embed_dim

        # num patches in each dimension
        t = T // self.patch_embed_3d.z_block_size
        h = self.patch_embed_3d.height
        w = self.patch_embed_3d.width

        feats = feats.permute(0, 2, 1)  # (B*S, D, T)
        feats = feats.view(B * S, D, t, h, w)  # (B*S, D, t, h, w)

        return feats


class BaseEncoderLayer(nn.TransformerEncoderLayer):
    '''
    This is a wrapper around nn.TransformerEncoderLayer that adds a CLS token
    to the sequence and outputs the CLS token's representation.
    This base class parents both SpatialEncoderLayer and TemporalEncoderLayer for the RGB stream
    and the FrequencyEncoderLayer and TemporalEncoderLayer for the audio stream.
    We also, optionally, add a positional embedding to the input sequence which
    allows to reuse it for global aggregation (of segments) for both streams.
    '''

    def __init__(self,
                 add_pos_emb: bool = False,
                 pos_emb_drop: float = None,
                 pos_max_len: int = None,
                 *args_transformer_enc,
                 **kwargs_transformer_enc):
        super().__init__(*args_transformer_enc, **kwargs_transformer_enc)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.self_attn.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # add positional embedding
        self.add_pos_emb = add_pos_emb
        if add_pos_emb:
            self.pos_max_len = 1 + pos_max_len  # +1 (for CLS)
            self.pos_emb = nn.Parameter(torch.zeros(1, self.pos_max_len, self.self_attn.embed_dim))
            self.pos_drop = nn.Dropout(pos_emb_drop)
            trunc_normal_(self.pos_emb, std=.02)

        self.apply(self._init_weights)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None):
        ''' x is of shape (B, N, D); if provided, x_mask is of shape (B, N)'''
        batch_dim = x.shape[0]

        # add CLS token
        cls_tokens = self.cls_token.expand(batch_dim, -1, -1)  # expanding to match batch dimension
        x = torch.cat((cls_tokens, x), dim=-2)  # (batch_dim, 1+seq_len, D)
        if x_mask is not None:
            cls_mask = torch.ones((batch_dim, 1), dtype=torch.bool,
                                  device=x_mask.device)  # 1=keep; 0=mask
            x_mask_w_cls = torch.cat((cls_mask, x_mask), dim=-1)  # (batch_dim, 1+seq_len)
            B, N = x_mask_w_cls.shape
            # torch expects (N, N) or (B*num_heads, N, N) mask (sadness ahead); torch masks
            x_mask_w_cls = x_mask_w_cls.reshape(B, 1, 1, N)\
                .expand(-1, self.self_attn.num_heads, N, -1)\
                .reshape(B * self.self_attn.num_heads, N, N)
            assert x_mask_w_cls.dtype == x_mask_w_cls.bool().dtype, 'x_mask_w_cls.dtype != bool'
            x_mask_w_cls = ~x_mask_w_cls  # invert mask (1=mask)
        else:
            x_mask_w_cls = None

        # add positional embedding
        if self.add_pos_emb:
            seq_len = x.shape[1]  # (don't even think about moving it before the CLS token concatenation)
            assert seq_len <= self.pos_max_len, f'Seq len ({seq_len}) > pos_max_len ({self.pos_max_len})'
            x = x + self.pos_emb[:, :seq_len, :]
            x = self.pos_drop(x)

        # apply encoder layer (calls nn.TransformerEncoderLayer.forward);
        x = super().forward(src=x, src_mask=x_mask_w_cls)  # (batch_dim, 1+seq_len, D)

        # CLS token is expected to hold spatial information for each frame
        x = x[:, 0, :]  # (batch_dim, D)

        return x

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token', 'pos_emb'}


class SpatialTransformerEncoderLayer(BaseEncoderLayer):
    ''' Aggregates spatial dimensions by applying attention individually to each frame. '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        ''' x is of shape (B*S, D, t, h, w) where S is the number of segments.
        if specified, x_mask is (B*S, t, h, w), 0=masked, 1=kept
        Returns a tensor of shape (B*S, t, D) pooling spatial information for each frame. '''
        BS, D, t, h, w = x.shape

        # time as a batch dimension and flatten spatial dimensions as sequence
        x = einops.rearrange(x, 'BS D t h w -> (BS t) (h w) D')
        # similarly for the mask
        if x_mask is not None:
            x_mask = einops.rearrange(x_mask, 'BS t h w -> (BS t) (h w)')

        # apply encoder layer (BaseEncoderLayer.forward) - it will add a CLS token and output its representation
        x = super().forward(x=x, x_mask=x_mask)  # (B*S*t, D)

        # reshape back to (B*S, t, D)
        x = einops.rearrange(x, '(BS t) D -> BS t D', BS=BS, t=t)

        # (B*S, t, D)
        return x


class TemporalTransformerEncoderLayer(BaseEncoderLayer):
    ''' Aggregates the temporal dimension with attention. Also used with pos emb as global aggregation
    in both streams. '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        ''' x is of shape (B*S, t, D) where S is the number of segments.
        Returns a tensor of shape (B*S, D) pooling temporal information. '''
        BS, t, D = x.shape

        # apply encoder layer (BaseEncoderLayer.forward) - it will add a CLS token and output its representation
        x = super().forward(x)  # (B*S, D)

        return x  # (B*S, D)


class AveragePooling(nn.Module):

    def __init__(self, avg_pattern: str, then_permute_pattern: str = None) -> None:
        ''' patterns are e.g. "bs t d -> bs d" '''
        super().__init__()
        # TODO: need to register them as buffers (but fails because these are strings)
        self.reduce_fn = 'mean'
        self.avg_pattern = avg_pattern
        self.then_permute_pattern = then_permute_pattern

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor = None) -> torch.Tensor:
        x = einops.reduce(x, self.avg_pattern, self.reduce_fn)
        if self.then_permute_pattern is not None:
            x = einops.rearrange(x, self.then_permute_pattern)
        return x
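Editor's note: a small sketch to make the docstring's shape claims concrete. The 16-frame / temporal-patch-size-2 numbers come from divided_224_16x4.yaml above, and the (2, 7, 8, 768) example is the one quoted in synchformer.py below; only the arithmetic is computed here.

# Token arithmetic behind "T = 1 + (224 // 16) * (224 // 16) * 8" in the MotionFormer docstring.
patches_per_side = 224 // 16   # 14 spatial patches along each side
temporal_slots = 16 // 2       # 8 temporal slots: NUM_FRAMES=16, PATCH_SIZE_TEMP=2
tokens = 1 + patches_per_side * patches_per_side * temporal_slots
print(tokens)                  # 1569 = 1 CLS token + 14*14*8 patch tokens

# With extract_features=True and factorize_space_time=True (the Synchformer configuration),
# a (B, S, C, T, H, W) = (2, 7, 3, 16, 224, 224) input comes back as (B, S, t, D) = (2, 7, 8, 768).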
data_utils/ext/synchformer/synchformer.py
ADDED
@@ -0,0 +1,55 @@
import logging
from typing import Any, Mapping

import torch
from torch import nn

from data_utils.ext.synchformer.motionformer import MotionFormer


class Synchformer(nn.Module):

    def __init__(self):
        super().__init__()

        self.vfeat_extractor = MotionFormer(extract_features=True,
                                            factorize_space_time=True,
                                            agg_space_module='TransformerEncoderLayer',
                                            agg_time_module='torch.nn.Identity',
                                            add_global_repr=False)

        # self.vfeat_extractor = instantiate_from_config(vfeat_extractor)
        # self.afeat_extractor = instantiate_from_config(afeat_extractor)
        # # bridging the s3d latent dim (1024) into what is specified in the config
        # # to match e.g. the transformer dim
        # self.vproj = instantiate_from_config(vproj)
        # self.aproj = instantiate_from_config(aproj)
        # self.transformer = instantiate_from_config(transformer)

    def forward(self, vis):
        B, S, Tv, C, H, W = vis.shape
        vis = vis.permute(0, 1, 3, 2, 4, 5)  # (B, S, C, Tv, H, W)
        # feat extractors return a tuple of segment-level and global features (ignored for sync)
        # (B, S, tv, D), e.g. (B, 7, 8, 768)
        vis = self.vfeat_extractor(vis)
        return vis

    def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
        # discard all entries except vfeat_extractor
        sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}

        return super().load_state_dict(sd, strict)


if __name__ == "__main__":
    model = Synchformer().cuda().eval()
    sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
    model.load_state_dict(sd)

    vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
    features = model.extract_vfeats(vid, for_loop=False).detach().cpu()
    print(features.shape)

    # extract and save the state dict only
    # sd = torch.load('./ext_weights/sync_model_audioset.pt')['model']
    # torch.save(sd, './ext_weights/synchformer_state_dict.pth')
data_utils/ext/synchformer/utils.py
ADDED
@@ -0,0 +1,92 @@
from hashlib import md5
from pathlib import Path

import requests
from tqdm import tqdm

PARENT_LINK = 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a'
FNAME2LINK = {
    # S3: Synchability: AudioSet (run 2)
    '24-01-22T20-34-52.pt':
    f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/24-01-22T20-34-52.pt',
    'cfg-24-01-22T20-34-52.yaml':
    f'{PARENT_LINK}/sync/sync_models/24-01-22T20-34-52/cfg-24-01-22T20-34-52.yaml',
    # S2: Synchformer: AudioSet (run 2)
    '24-01-04T16-39-21.pt':
    f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/24-01-04T16-39-21.pt',
    'cfg-24-01-04T16-39-21.yaml':
    f'{PARENT_LINK}/sync/sync_models/24-01-04T16-39-21/cfg-24-01-04T16-39-21.yaml',
    # S2: Synchformer: AudioSet (run 1)
    '23-08-28T11-23-23.pt':
    f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/23-08-28T11-23-23.pt',
    'cfg-23-08-28T11-23-23.yaml':
    f'{PARENT_LINK}/sync/sync_models/23-08-28T11-23-23/cfg-23-08-28T11-23-23.yaml',
    # S2: Synchformer: LRS3 (run 2)
    '23-12-23T18-33-57.pt':
    f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/23-12-23T18-33-57.pt',
    'cfg-23-12-23T18-33-57.yaml':
    f'{PARENT_LINK}/sync/sync_models/23-12-23T18-33-57/cfg-23-12-23T18-33-57.yaml',
    # S2: Synchformer: VGS (run 2)
    '24-01-02T10-00-53.pt':
    f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/24-01-02T10-00-53.pt',
    'cfg-24-01-02T10-00-53.yaml':
    f'{PARENT_LINK}/sync/sync_models/24-01-02T10-00-53/cfg-24-01-02T10-00-53.yaml',
    # SparseSync: ft VGGSound-Full
    '22-09-21T21-00-52.pt':
    f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/22-09-21T21-00-52.pt',
    'cfg-22-09-21T21-00-52.yaml':
    f'{PARENT_LINK}/sync/sync_models/22-09-21T21-00-52/cfg-22-09-21T21-00-52.yaml',
    # SparseSync: ft VGGSound-Sparse
    '22-07-28T15-49-45.pt':
    f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/22-07-28T15-49-45.pt',
    'cfg-22-07-28T15-49-45.yaml':
    f'{PARENT_LINK}/sync/sync_models/22-07-28T15-49-45/cfg-22-07-28T15-49-45.yaml',
    # SparseSync: only pt on LRS3
    '22-07-13T22-25-49.pt':
    f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/22-07-13T22-25-49.pt',
    'cfg-22-07-13T22-25-49.yaml':
    f'{PARENT_LINK}/sync/sync_models/22-07-13T22-25-49/cfg-22-07-13T22-25-49.yaml',
    # SparseSync: feature extractors
    'ResNetAudio-22-08-04T09-51-04.pt':
    f'{PARENT_LINK}/sync/ResNetAudio-22-08-04T09-51-04.pt',  # 2s
    'ResNetAudio-22-08-03T23-14-49.pt':
    f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-49.pt',  # 3s
    'ResNetAudio-22-08-03T23-14-28.pt':
    f'{PARENT_LINK}/sync/ResNetAudio-22-08-03T23-14-28.pt',  # 4s
    'ResNetAudio-22-06-24T08-10-33.pt':
    f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T08-10-33.pt',  # 5s
    'ResNetAudio-22-06-24T17-31-07.pt':
    f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T17-31-07.pt',  # 6s
    'ResNetAudio-22-06-24T23-57-11.pt':
    f'{PARENT_LINK}/sync/ResNetAudio-22-06-24T23-57-11.pt',  # 7s
    'ResNetAudio-22-06-25T04-35-42.pt':
    f'{PARENT_LINK}/sync/ResNetAudio-22-06-25T04-35-42.pt',  # 8s
}


def check_if_file_exists_else_download(path, fname2link=FNAME2LINK, chunk_size=1024):
    '''Checks if file exists, if not downloads it from the link to the path'''
    path = Path(path)
    if not path.exists():
        path.parent.mkdir(exist_ok=True, parents=True)
        link = fname2link.get(path.name, None)
        if link is None:
            raise ValueError(f'Cant find the checkpoint file: {path}.',
                             f'Please download it manually and ensure the path exists.')
        with requests.get(fname2link[path.name], stream=True) as r:
            total_size = int(r.headers.get('content-length', 0))
            with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                with open(path, 'wb') as f:
                    for data in r.iter_content(chunk_size=chunk_size):
                        if data:
                            f.write(data)
                            pbar.update(chunk_size)


def get_md5sum(path):
    hash_md5 = md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(4096 * 8), b''):
            hash_md5.update(chunk)
    md5sum = hash_md5.hexdigest()
    return md5sum
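Editor's note: a minimal usage sketch for these helpers. The ./ext_weights directory mirrors the one used in synchformer.py's __main__ block and is an assumption; the filename must be one of the FNAME2LINK keys, otherwise the function raises.

# Fetch a Synchformer checkpoint listed in FNAME2LINK (downloads only if missing), then hash it.
from data_utils.ext.synchformer.utils import (FNAME2LINK, check_if_file_exists_else_download,
                                               get_md5sum)

ckpt_path = "./ext_weights/24-01-04T16-39-21.pt"  # filename is a key in FNAME2LINK; directory assumed
check_if_file_exists_else_download(ckpt_path, FNAME2LINK)
print(get_md5sum(ckpt_path))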
data_utils/ext/synchformer/video_model_builder.py
ADDED
@@ -0,0 +1,277 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified Model definition

from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from timm.layers import trunc_normal_

from data_utils.ext.synchformer import vit_helper


class VisionTransformer(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage """

    def __init__(self, cfg):
        super().__init__()
        self.img_size = cfg.DATA.TRAIN_CROP_SIZE
        self.patch_size = cfg.VIT.PATCH_SIZE
        self.in_chans = cfg.VIT.CHANNELS
        if cfg.TRAIN.DATASET == "Epickitchens":
            self.num_classes = [97, 300]
        else:
            self.num_classes = cfg.MODEL.NUM_CLASSES
        self.embed_dim = cfg.VIT.EMBED_DIM
        self.depth = cfg.VIT.DEPTH
        self.num_heads = cfg.VIT.NUM_HEADS
        self.mlp_ratio = cfg.VIT.MLP_RATIO
        self.qkv_bias = cfg.VIT.QKV_BIAS
        self.drop_rate = cfg.VIT.DROP
        self.drop_path_rate = cfg.VIT.DROP_PATH
        self.head_dropout = cfg.VIT.HEAD_DROPOUT
        self.video_input = cfg.VIT.VIDEO_INPUT
        self.temporal_resolution = cfg.VIT.TEMPORAL_RESOLUTION
        self.use_mlp = cfg.VIT.USE_MLP
        self.num_features = self.embed_dim
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.attn_drop_rate = cfg.VIT.ATTN_DROPOUT
        self.head_act = cfg.VIT.HEAD_ACT
        self.cfg = cfg

        # Patch Embedding
        self.patch_embed = vit_helper.PatchEmbed(img_size=224,
                                                 patch_size=self.patch_size,
                                                 in_chans=self.in_chans,
                                                 embed_dim=self.embed_dim)

        # 3D Patch Embedding
        self.patch_embed_3d = vit_helper.PatchEmbed3D(img_size=self.img_size,
                                                      temporal_resolution=self.temporal_resolution,
                                                      patch_size=self.patch_size,
                                                      in_chans=self.in_chans,
                                                      embed_dim=self.embed_dim,
                                                      z_block_size=self.cfg.VIT.PATCH_SIZE_TEMP)
        self.patch_embed_3d.proj.weight.data = torch.zeros_like(
            self.patch_embed_3d.proj.weight.data)

        # Number of patches
        if self.video_input:
            num_patches = self.patch_embed.num_patches * self.temporal_resolution
        else:
            num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        # CLS token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.cls_token, std=.02)

        # Positional embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 1, self.embed_dim))
        self.pos_drop = nn.Dropout(p=cfg.VIT.POS_DROPOUT)
        trunc_normal_(self.pos_embed, std=.02)

        if self.cfg.VIT.POS_EMBED == "joint":
            self.st_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
            trunc_normal_(self.st_embed, std=.02)
        elif self.cfg.VIT.POS_EMBED == "separate":
            self.temp_embed = nn.Parameter(torch.zeros(1, self.temporal_resolution, self.embed_dim))

        # Layer Blocks
        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
        if self.cfg.VIT.ATTN_LAYER == "divided":
            self.blocks = nn.ModuleList([
                vit_helper.DividedSpaceTimeBlock(
                    attn_type=cfg.VIT.ATTN_LAYER,
                    dim=self.embed_dim,
                    num_heads=self.num_heads,
                    mlp_ratio=self.mlp_ratio,
                    qkv_bias=self.qkv_bias,
                    drop=self.drop_rate,
+
attn_drop=self.attn_drop_rate,
|
96 |
+
drop_path=dpr[i],
|
97 |
+
norm_layer=norm_layer,
|
98 |
+
) for i in range(self.depth)
|
99 |
+
])
|
100 |
+
else:
|
101 |
+
self.blocks = nn.ModuleList([
|
102 |
+
vit_helper.Block(attn_type=cfg.VIT.ATTN_LAYER,
|
103 |
+
dim=self.embed_dim,
|
104 |
+
num_heads=self.num_heads,
|
105 |
+
mlp_ratio=self.mlp_ratio,
|
106 |
+
qkv_bias=self.qkv_bias,
|
107 |
+
drop=self.drop_rate,
|
108 |
+
attn_drop=self.attn_drop_rate,
|
109 |
+
drop_path=dpr[i],
|
110 |
+
norm_layer=norm_layer,
|
111 |
+
use_original_code=self.cfg.VIT.USE_ORIGINAL_TRAJ_ATTN_CODE)
|
112 |
+
for i in range(self.depth)
|
113 |
+
])
|
114 |
+
self.norm = norm_layer(self.embed_dim)
|
115 |
+
|
116 |
+
# MLP head
|
117 |
+
if self.use_mlp:
|
118 |
+
hidden_dim = self.embed_dim
|
119 |
+
if self.head_act == 'tanh':
|
120 |
+
# logging.info("Using TanH activation in MLP")
|
121 |
+
act = nn.Tanh()
|
122 |
+
elif self.head_act == 'gelu':
|
123 |
+
# logging.info("Using GELU activation in MLP")
|
124 |
+
act = nn.GELU()
|
125 |
+
else:
|
126 |
+
# logging.info("Using ReLU activation in MLP")
|
127 |
+
act = nn.ReLU()
|
128 |
+
self.pre_logits = nn.Sequential(
|
129 |
+
OrderedDict([
|
130 |
+
('fc', nn.Linear(self.embed_dim, hidden_dim)),
|
131 |
+
('act', act),
|
132 |
+
]))
|
133 |
+
else:
|
134 |
+
self.pre_logits = nn.Identity()
|
135 |
+
|
136 |
+
# Classifier Head
|
137 |
+
self.head_drop = nn.Dropout(p=self.head_dropout)
|
138 |
+
if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
|
139 |
+
for a, i in enumerate(range(len(self.num_classes))):
|
140 |
+
setattr(self, "head%d" % a, nn.Linear(self.embed_dim, self.num_classes[i]))
|
141 |
+
else:
|
142 |
+
self.head = nn.Linear(self.embed_dim,
|
143 |
+
self.num_classes) if self.num_classes > 0 else nn.Identity()
|
144 |
+
|
145 |
+
# Initialize weights
|
146 |
+
self.apply(self._init_weights)
|
147 |
+
|
148 |
+
def _init_weights(self, m):
|
149 |
+
if isinstance(m, nn.Linear):
|
150 |
+
trunc_normal_(m.weight, std=.02)
|
151 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
152 |
+
nn.init.constant_(m.bias, 0)
|
153 |
+
elif isinstance(m, nn.LayerNorm):
|
154 |
+
nn.init.constant_(m.bias, 0)
|
155 |
+
nn.init.constant_(m.weight, 1.0)
|
156 |
+
|
157 |
+
@torch.jit.ignore
|
158 |
+
def no_weight_decay(self):
|
159 |
+
if self.cfg.VIT.POS_EMBED == "joint":
|
160 |
+
return {'pos_embed', 'cls_token', 'st_embed'}
|
161 |
+
else:
|
162 |
+
return {'pos_embed', 'cls_token', 'temp_embed'}
|
163 |
+
|
164 |
+
def get_classifier(self):
|
165 |
+
return self.head
|
166 |
+
|
167 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
168 |
+
self.num_classes = num_classes
|
169 |
+
self.head = (nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity())
|
170 |
+
|
171 |
+
def forward_features(self, x):
|
172 |
+
# if self.video_input:
|
173 |
+
# x = x[0]
|
174 |
+
B = x.shape[0]
|
175 |
+
|
176 |
+
# Tokenize input
|
177 |
+
# if self.cfg.VIT.PATCH_SIZE_TEMP > 1:
|
178 |
+
# for simplicity of mapping between content dimensions (input x) and token dims (after patching)
|
179 |
+
# we use the same trick as for AST (see modeling_ast.ASTModel.forward for the details):
|
180 |
+
|
181 |
+
# apply patching on input
|
182 |
+
x = self.patch_embed_3d(x)
|
183 |
+
tok_mask = None
|
184 |
+
|
185 |
+
# else:
|
186 |
+
# tok_mask = None
|
187 |
+
# # 2D tokenization
|
188 |
+
# if self.video_input:
|
189 |
+
# x = x.permute(0, 2, 1, 3, 4)
|
190 |
+
# (B, T, C, H, W) = x.shape
|
191 |
+
# x = x.reshape(B * T, C, H, W)
|
192 |
+
|
193 |
+
# x = self.patch_embed(x)
|
194 |
+
|
195 |
+
# if self.video_input:
|
196 |
+
# (B2, T2, D2) = x.shape
|
197 |
+
# x = x.reshape(B, T * T2, D2)
|
198 |
+
|
199 |
+
# Append CLS token
|
200 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
201 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
202 |
+
# if tok_mask is not None:
|
203 |
+
# # prepend 1(=keep) to the mask to account for the CLS token as well
|
204 |
+
# tok_mask = torch.cat((torch.ones_like(tok_mask[:, [0]]), tok_mask), dim=1)
|
205 |
+
|
206 |
+
# Interpolate positinoal embeddings
|
207 |
+
# if self.cfg.DATA.TRAIN_CROP_SIZE != 224:
|
208 |
+
# pos_embed = self.pos_embed
|
209 |
+
# N = pos_embed.shape[1] - 1
|
210 |
+
# npatch = int((x.size(1) - 1) / self.temporal_resolution)
|
211 |
+
# class_emb = pos_embed[:, 0]
|
212 |
+
# pos_embed = pos_embed[:, 1:]
|
213 |
+
# dim = x.shape[-1]
|
214 |
+
# pos_embed = torch.nn.functional.interpolate(
|
215 |
+
# pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
216 |
+
# scale_factor=math.sqrt(npatch / N),
|
217 |
+
# mode='bicubic',
|
218 |
+
# )
|
219 |
+
# pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
220 |
+
# new_pos_embed = torch.cat((class_emb.unsqueeze(0), pos_embed), dim=1)
|
221 |
+
# else:
|
222 |
+
new_pos_embed = self.pos_embed
|
223 |
+
npatch = self.patch_embed.num_patches
|
224 |
+
|
225 |
+
# Add positional embeddings to input
|
226 |
+
if self.video_input:
|
227 |
+
if self.cfg.VIT.POS_EMBED == "separate":
|
228 |
+
cls_embed = self.pos_embed[:, 0, :].unsqueeze(1)
|
229 |
+
tile_pos_embed = new_pos_embed[:, 1:, :].repeat(1, self.temporal_resolution, 1)
|
230 |
+
tile_temporal_embed = self.temp_embed.repeat_interleave(npatch, 1)
|
231 |
+
total_pos_embed = tile_pos_embed + tile_temporal_embed
|
232 |
+
total_pos_embed = torch.cat([cls_embed, total_pos_embed], dim=1)
|
233 |
+
x = x + total_pos_embed
|
234 |
+
elif self.cfg.VIT.POS_EMBED == "joint":
|
235 |
+
x = x + self.st_embed
|
236 |
+
else:
|
237 |
+
# image input
|
238 |
+
x = x + new_pos_embed
|
239 |
+
|
240 |
+
# Apply positional dropout
|
241 |
+
x = self.pos_drop(x)
|
242 |
+
|
243 |
+
# Encoding using transformer layers
|
244 |
+
for i, blk in enumerate(self.blocks):
|
245 |
+
x = blk(x,
|
246 |
+
seq_len=npatch,
|
247 |
+
num_frames=self.temporal_resolution,
|
248 |
+
approx=self.cfg.VIT.APPROX_ATTN_TYPE,
|
249 |
+
num_landmarks=self.cfg.VIT.APPROX_ATTN_DIM,
|
250 |
+
tok_mask=tok_mask)
|
251 |
+
|
252 |
+
### v-iashin: I moved it to the forward pass
|
253 |
+
# x = self.norm(x)[:, 0]
|
254 |
+
# x = self.pre_logits(x)
|
255 |
+
###
|
256 |
+
return x, tok_mask
|
257 |
+
|
258 |
+
# def forward(self, x):
|
259 |
+
# x = self.forward_features(x)
|
260 |
+
# ### v-iashin: here. This should leave the same forward output as before
|
261 |
+
# x = self.norm(x)[:, 0]
|
262 |
+
# x = self.pre_logits(x)
|
263 |
+
# ###
|
264 |
+
# x = self.head_drop(x)
|
265 |
+
# if isinstance(self.num_classes, (list, )) and len(self.num_classes) > 1:
|
266 |
+
# output = []
|
267 |
+
# for head in range(len(self.num_classes)):
|
268 |
+
# x_out = getattr(self, "head%d" % head)(x)
|
269 |
+
# if not self.training:
|
270 |
+
# x_out = torch.nn.functional.softmax(x_out, dim=-1)
|
271 |
+
# output.append(x_out)
|
272 |
+
# return output
|
273 |
+
# else:
|
274 |
+
# x = self.head(x)
|
275 |
+
# if not self.training:
|
276 |
+
# x = torch.nn.functional.softmax(x, dim=-1)
|
277 |
+
# return x
|
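Note that forward_features returns the full token sequence plus the token mask; applying norm, pre_logits and the classifier head is left to the caller, as the commented-out forward above indicates. Below is a minimal instantiation sketch with a hand-rolled SimpleNamespace config. Every config value here is an illustrative assumption, not the project's training config, and only the 'divided' attention branch is exercised since vit_helper defines DividedSpaceTimeBlock.

# Minimal instantiation sketch; all config values below are illustrative assumptions.
from types import SimpleNamespace

import torch

from data_utils.ext.synchformer.video_model_builder import VisionTransformer

cfg = SimpleNamespace(
    DATA=SimpleNamespace(TRAIN_CROP_SIZE=224),
    TRAIN=SimpleNamespace(DATASET='vggsound'),
    MODEL=SimpleNamespace(NUM_CLASSES=10),
    VIT=SimpleNamespace(
        PATCH_SIZE=16, PATCH_SIZE_TEMP=2, CHANNELS=3, EMBED_DIM=96, DEPTH=2,
        NUM_HEADS=4, MLP_RATIO=4., QKV_BIAS=True, DROP=0., DROP_PATH=0.,
        ATTN_DROPOUT=0., POS_DROPOUT=0., HEAD_DROPOUT=0., HEAD_ACT='tanh',
        VIDEO_INPUT=True, TEMPORAL_RESOLUTION=4, USE_MLP=False,
        POS_EMBED='separate', ATTN_LAYER='divided',
        APPROX_ATTN_TYPE='none', APPROX_ATTN_DIM=64,
        USE_ORIGINAL_TRAJ_ATTN_CODE=True),
)

model = VisionTransformer(cfg).eval()
clip = torch.randn(1, 3, 8, 224, 224)  # (B, C, T, H, W); T / PATCH_SIZE_TEMP = 4 token frames
with torch.no_grad():
    tokens, tok_mask = model.forward_features(clip)
print(tokens.shape)  # torch.Size([1, 785, 96]): 1 CLS token + 4 * 196 patch tokens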
data_utils/ext/synchformer/vit_helper.py
ADDED
@@ -0,0 +1,399 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copyright 2020 Ross Wightman
# Modified Model definition
"""Video models."""

import math

import torch
import torch.nn as nn
from einops import rearrange, repeat
from timm.layers import to_2tuple
from torch import einsum
from torch.nn import functional as F

default_cfgs = {
    'vit_1k':
    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
    'vit_1k_large':
    'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
}


def qkv_attn(q, k, v, tok_mask: torch.Tensor = None):
    sim = einsum('b i d, b j d -> b i j', q, k)
    # apply masking if provided, tok_mask is (B*S*H, N): 1s - keep; sim is (B*S*H, H, N, N)
    if tok_mask is not None:
        BSH, N = tok_mask.shape
        sim = sim.masked_fill(tok_mask.view(BSH, 1, N) == 0,
                              float('-inf'))  # 1 - broadcasts across N
    attn = sim.softmax(dim=-1)
    out = einsum('b i j, b j d -> b i d', attn, v)
    return out


class DividedAttention(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        # init to zeros
        self.qkv.weight.data.fill_(0)
        self.qkv.bias.data.fill_(0)
        self.proj.weight.data.fill_(1)
        self.proj.bias.data.fill_(0)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, einops_from, einops_to, tok_mask: torch.Tensor = None, **einops_dims):
        # num of heads variable
        h = self.num_heads

        # project x to q, k, v values
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        if tok_mask is not None:
            # replicate token mask across heads (b, n) -> (b, h, n) -> (b*h, n) -- same as qkv but w/o d
            assert len(tok_mask.shape) == 2
            tok_mask = tok_mask.unsqueeze(1).expand(-1, h, -1).reshape(-1, tok_mask.shape[1])

        # Scale q
        q *= self.scale

        # Take out cls_q, cls_k, cls_v
        (cls_q, q_), (cls_k, k_), (cls_v, v_) = map(lambda t: (t[:, 0:1], t[:, 1:]), (q, k, v))
        # the same for masking
        if tok_mask is not None:
            cls_mask, mask_ = tok_mask[:, 0:1], tok_mask[:, 1:]
        else:
            cls_mask, mask_ = None, None

        # let CLS token attend to key / values of all patches across time and space
        cls_out = qkv_attn(cls_q, k, v, tok_mask=tok_mask)

        # rearrange across time or space
        q_, k_, v_ = map(lambda t: rearrange(t, f'{einops_from} -> {einops_to}', **einops_dims),
                         (q_, k_, v_))

        # expand CLS token keys and values across time or space and concat
        r = q_.shape[0] // cls_k.shape[0]
        cls_k, cls_v = map(lambda t: repeat(t, 'b () d -> (b r) () d', r=r), (cls_k, cls_v))

        k_ = torch.cat((cls_k, k_), dim=1)
        v_ = torch.cat((cls_v, v_), dim=1)

        # the same for masking (if provided)
        if tok_mask is not None:
            # since mask does not have the latent dim (d), we need to remove it from einops dims
            mask_ = rearrange(mask_, f'{einops_from} -> {einops_to}'.replace(' d', ''),
                              **einops_dims)
            cls_mask = repeat(cls_mask, 'b () -> (b r) ()',
                              r=r)  # expand cls_mask across time or space
            mask_ = torch.cat((cls_mask, mask_), dim=1)

        # attention
        out = qkv_attn(q_, k_, v_, tok_mask=mask_)

        # merge back time or space
        out = rearrange(out, f'{einops_to} -> {einops_from}', **einops_dims)

        # concat back the cls token
        out = torch.cat((cls_out, out), dim=1)

        # merge back the heads
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        ## to out
        x = self.proj(out)
        x = self.proj_drop(x)
        return x


class DividedSpaceTimeBlock(nn.Module):

    def __init__(self,
                 dim=768,
                 num_heads=12,
                 attn_type='divided',
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()

        self.einops_from_space = 'b (f n) d'
        self.einops_to_space = '(b f) n d'
        self.einops_from_time = 'b (f n) d'
        self.einops_to_time = '(b n) f d'

        self.norm1 = norm_layer(dim)

        self.attn = DividedAttention(dim,
                                     num_heads=num_heads,
                                     qkv_bias=qkv_bias,
                                     attn_drop=attn_drop,
                                     proj_drop=drop)

        self.timeattn = DividedAttention(dim,
                                         num_heads=num_heads,
                                         qkv_bias=qkv_bias,
                                         attn_drop=attn_drop,
                                         proj_drop=drop)

        # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=mlp_hidden_dim,
                       act_layer=act_layer,
                       drop=drop)
        self.norm3 = norm_layer(dim)

    def forward(self,
                x,
                seq_len=196,
                num_frames=8,
                approx='none',
                num_landmarks=128,
                tok_mask: torch.Tensor = None):
        time_output = self.timeattn(self.norm3(x),
                                    self.einops_from_time,
                                    self.einops_to_time,
                                    n=seq_len,
                                    tok_mask=tok_mask)
        time_residual = x + time_output

        space_output = self.attn(self.norm1(time_residual),
                                 self.einops_from_space,
                                 self.einops_to_space,
                                 f=num_frames,
                                 tok_mask=tok_mask)
        space_residual = time_residual + self.drop_path(space_output)

        x = space_residual
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = img_size if type(img_size) is tuple else to_2tuple(img_size)
        patch_size = img_size if type(patch_size) is tuple else to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class PatchEmbed3D(nn.Module):
    """ Image to Patch Embedding """

    def __init__(self,
                 img_size=224,
                 temporal_resolution=4,
                 in_chans=3,
                 patch_size=16,
                 z_block_size=2,
                 embed_dim=768,
                 flatten=True):
        super().__init__()
        self.height = (img_size // patch_size)
        self.width = (img_size // patch_size)
        ### v-iashin: these two are incorrect
        # self.frames = (temporal_resolution // z_block_size)
        # self.num_patches = self.height * self.width * self.frames
        self.z_block_size = z_block_size
        ###
        self.proj = nn.Conv3d(in_chans,
                              embed_dim,
                              kernel_size=(z_block_size, patch_size, patch_size),
                              stride=(z_block_size, patch_size, patch_size))
        self.flatten = flatten

    def forward(self, x):
        B, C, T, H, W = x.shape
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)
        return x


class HeadMLP(nn.Module):

    def __init__(self, n_input, n_classes, n_hidden=512, p=0.1):
        super(HeadMLP, self).__init__()
        self.n_input = n_input
        self.n_classes = n_classes
        self.n_hidden = n_hidden
        if n_hidden is None:
            # use linear classifier
            self.block_forward = nn.Sequential(nn.Dropout(p=p),
                                               nn.Linear(n_input, n_classes, bias=True))
        else:
            # use simple MLP classifier
            self.block_forward = nn.Sequential(nn.Dropout(p=p),
                                               nn.Linear(n_input, n_hidden, bias=True),
                                               nn.BatchNorm1d(n_hidden), nn.ReLU(inplace=True),
                                               nn.Dropout(p=p),
                                               nn.Linear(n_hidden, n_classes, bias=True))
        print(f"Dropout-NLP: {p}")

    def forward(self, x):
        return self.block_forward(x)


def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v
    return out_dict


def adapt_input_conv(in_chans, conv_weight, agg='sum'):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            if agg == 'sum':
                print("Summing conv1 weights")
                conv_weight = conv_weight.sum(dim=1, keepdim=True)
            else:
                print("Averaging conv1 weights")
                conv_weight = conv_weight.mean(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            if agg == 'sum':
                print("Summing conv1 weights")
                repeat = int(math.ceil(in_chans / 3))
                conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
                conv_weight *= (3 / float(in_chans))
            else:
                print("Averaging conv1 weights")
                conv_weight = conv_weight.mean(dim=1, keepdim=True)
                conv_weight = conv_weight.repeat(1, in_chans, 1, 1)
    conv_weight = conv_weight.to(conv_type)
    return conv_weight


def load_pretrained(model,
                    cfg=None,
                    num_classes=1000,
                    in_chans=3,
                    filter_fn=None,
                    strict=True,
                    progress=False):
    # Load state dict
    assert (f"{cfg.VIT.PRETRAINED_WEIGHTS} not in [vit_1k, vit_1k_large]")
    state_dict = torch.hub.load_state_dict_from_url(url=default_cfgs[cfg.VIT.PRETRAINED_WEIGHTS])

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    input_convs = 'patch_embed.proj'
    if input_convs is not None and in_chans != 3:
        if isinstance(input_convs, str):
            input_convs = (input_convs, )
        for input_conv_name in input_convs:
            weight_name = input_conv_name + '.weight'
            try:
                state_dict[weight_name] = adapt_input_conv(in_chans,
                                                           state_dict[weight_name],
                                                           agg='avg')
                print(
                    f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)'
                )
            except NotImplementedError as e:
                del state_dict[weight_name]
                strict = False
                print(
                    f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.'
                )

    classifier_name = 'head'
    label_offset = cfg.get('label_offset', 0)
    pretrain_classes = 1000
    if num_classes != pretrain_classes:
        # completely discard fully connected if model num_classes doesn't match pretrained weights
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False
    elif label_offset > 0:
        # special case for pretrained weights with an extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

    loaded_state = state_dict
    self_state = model.state_dict()
    all_names = set(self_state.keys())
    saved_names = set([])
    for name, param in loaded_state.items():
        param = param
        if 'module.' in name:
            name = name.replace('module.', '')
        if name in self_state.keys() and param.shape == self_state[name].shape:
            saved_names.add(name)
            self_state[name].copy_(param)
        else:
            print(f"didn't load: {name} of shape: {param.shape}")
    print("Missing Keys:")
    print(all_names - saved_names)
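A small self-contained check of the masking convention used by qkv_attn: q, k and v are shaped (batch * heads, tokens, head_dim), the mask is (batch * heads, tokens) with 1 meaning keep, and masked tokens receive no attention weight. The shapes below are arbitrary toy values.

# Toy check of qkv_attn masking; shapes are illustrative only.
import torch
from data_utils.ext.synchformer.vit_helper import qkv_attn

bh, n, d = 2, 5, 8                   # (batch*heads, tokens, head dim)
q = torch.randn(bh, n, d)
k = torch.randn(bh, n, d)
v = torch.randn(bh, n, d)
tok_mask = torch.ones(bh, n, dtype=torch.long)
tok_mask[:, -2:] = 0                 # pretend the last two tokens are padding
out = qkv_attn(q, k, v, tok_mask=tok_mask)
print(out.shape)                     # torch.Size([2, 5, 8])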
data_utils/utils.py
ADDED
@@ -0,0 +1,115 @@
"""Utility functions."""
import contextlib
import csv
import json
import os
import pathlib
import warnings

import numpy as np


def save_args(filename, args):
    """Save the command-line arguments."""
    args_dict = {}
    for key, value in vars(args).items():
        if isinstance(value, pathlib.Path):
            args_dict[key] = str(value)
        else:
            args_dict[key] = value
    save_json(filename, args_dict)


def inverse_dict(d):
    """Return the inverse dictionary."""
    return {v: k for k, v in d.items()}


def save_txt(filename, data):
    """Save a list to a TXT file."""
    with open(filename, "w", encoding="utf8") as f:
        for item in data:
            f.write(f"{item}\n")


def load_txt(filename):
    """Load a TXT file as a list."""
    with open(filename, encoding="utf8") as f:
        return [line.strip() for line in f]


def save_json(filename, data):
    """Save data as a JSON file."""
    with open(filename, "w", encoding="utf8") as f:
        json.dump(data, f)


def load_json(filename):
    """Load data from a JSON file."""
    with open(filename, encoding="utf8") as f:
        return json.load(f)


def save_csv(filename, data, header=""):
    """Save data as a CSV file."""
    np.savetxt(
        filename, data, fmt="%d", delimiter=",", header=header, comments=""
    )


def load_csv(filename, skiprows=1):
    """Load data from a CSV file."""
    return np.loadtxt(filename, dtype=int, delimiter=",", skiprows=skiprows)


def load_csv_text(filename, headerless=True):
    """Read a CSV file into a list of dictionaries or lists."""
    with open(filename) as f:
        if headerless:
            return [row for row in csv.reader(f)]
        reader = csv.DictReader(f)
        return [
            {field: row[field] for field in reader.fieldnames}
            for row in reader
        ]


def ignore_exceptions(func):
    """Decorator that ignores all errors and warnings."""

    def inner(*args, **kwargs):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                return func(*args, **kwargs)
            except Exception:
                return None

    return inner


def suppress_outputs(func):
    """Decorator that suppresses writing to stdout and stderr."""

    def inner(*args, **kwargs):
        devnull = open(os.devnull, "w")
        with contextlib.redirect_stdout(devnull):
            with contextlib.redirect_stderr(devnull):
                return func(*args, **kwargs)

    return inner


def resolve_paths(func):
    """Decorator that resolves all paths."""

    def inner(*args, **kwargs):
        parsed = func(*args, **kwargs)
        for key in vars(parsed).keys():
            if isinstance(getattr(parsed, key), pathlib.Path):
                setattr(
                    parsed, key, getattr(parsed, key).expanduser().resolve()
                )
        return parsed

    return inner
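A short sketch of how these helpers compose, assuming the repository root is on PYTHONPATH; the file name and values are examples only.

# Example only: JSON round trip plus the ignore_exceptions decorator, which turns a
# failing call into a silent None (handy when scanning large, partially broken datasets).
from data_utils.utils import save_json, load_json, ignore_exceptions

save_json("example.json", {"sr": 44100, "duration": 8})
print(load_json("example.json"))                     # {'sr': 44100, 'duration': 8}

@ignore_exceptions
def parse_int(s):
    return int(s)

print(parse_int("42"), parse_int("not-a-number"))    # 42 None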
data_utils/v2a_utils/__init__.py  ADDED  File without changes
data_utils/v2a_utils/__pycache__/__init__.cpython-310.pyc  ADDED  Binary file (163 Bytes)
data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-310.pyc  ADDED  Binary file (4.05 kB)
data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-38.pyc  ADDED  Binary file (4.06 kB)
data_utils/v2a_utils/__pycache__/audio_text_dataset.cpython-39.pyc  ADDED  Binary file (4.09 kB)
data_utils/v2a_utils/__pycache__/audioset_224.cpython-39.pyc  ADDED  Binary file (6.64 kB)
data_utils/v2a_utils/__pycache__/audioset_video_224.cpython-39.pyc  ADDED  Binary file (5.84 kB)
data_utils/v2a_utils/__pycache__/feature_utils.cpython-310.pyc  ADDED  Binary file (5.23 kB)
data_utils/v2a_utils/__pycache__/feature_utils.cpython-39.pyc  ADDED  Binary file (6.59 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-310.pyc  ADDED  Binary file (5.94 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224.cpython-39.pyc  ADDED  Binary file (5.95 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-310.pyc  ADDED  Binary file (4.53 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-38.pyc  ADDED  Binary file (4.4 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224_audio.cpython-39.pyc  ADDED  Binary file (4.49 kB)
data_utils/v2a_utils/__pycache__/feature_utils_224_no_sync.cpython-39.pyc  ADDED  Binary file (4.75 kB)
data_utils/v2a_utils/__pycache__/vggsound.cpython-310.pyc  ADDED  Binary file (4.99 kB)
data_utils/v2a_utils/__pycache__/vggsound.cpython-39.pyc  ADDED  Binary file (5.18 kB)
data_utils/v2a_utils/__pycache__/vggsound_224.cpython-310.pyc  ADDED  Binary file (6.56 kB)
data_utils/v2a_utils/__pycache__/vggsound_224.cpython-39.pyc  ADDED  Binary file (6.5 kB)
data_utils/v2a_utils/__pycache__/vggsound_224_no_audio.cpython-310.pyc  ADDED  Binary file (5.64 kB)
data_utils/v2a_utils/__pycache__/vggsound_224_no_sync.cpython-39.pyc  ADDED  Binary file (5.14 kB)
data_utils/v2a_utils/__pycache__/vggsound_text.cpython-39.pyc  ADDED  Binary file (2.43 kB)
data_utils/v2a_utils/feature_utils_224.py
ADDED
@@ -0,0 +1,182 @@
from typing import Literal, Optional
import json
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize
from think_sound.models.factory import create_model_from_config
from think_sound.models.utils import load_ckpt_state_dict
from think_sound.training.utils import copy_state_dict
from transformers import AutoModel
from transformers import AutoProcessor
from transformers import T5EncoderModel, AutoTokenizer
import logging
from data_utils.ext.synchformer import Synchformer

log = logging.getLogger()


def patch_clip(clip_model):
    # a hack to make it output last hidden states
    # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
    def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None,
                              output_attentions: Optional[bool] = None,
                              output_hidden_states: Optional[bool] = None,
                              return_dict: Optional[bool] = None):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = text_outputs[0]
        pooled_output = text_outputs[1]
        text_features = self.text_projection(pooled_output)

        return text_features, last_hidden_state

    clip_model.get_text_features = new_get_text_features.__get__(clip_model)
    return clip_model


class FeaturesUtils(nn.Module):

    def __init__(
        self,
        *,
        vae_ckpt: Optional[str] = None,
        vae_config: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        need_vae_encoder: bool = True,
    ):
        super().__init__()

        if enable_conditions:
            self.clip_model = AutoModel.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
            self.clip_model = patch_clip(self.clip_model)
            self.t5_tokenizer = AutoTokenizer.from_pretrained("google/t5-v1_1-xl")
            self.t5_model = T5EncoderModel.from_pretrained("google/t5-v1_1-xl")
            self.clip_processor = AutoProcessor.from_pretrained("facebook/metaclip-h14-fullcc2.5b")
            # self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
            #                                  std=[0.26862954, 0.26130258, 0.27577711])
            self.synchformer = Synchformer()
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))

            # self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')  # same as 'ViT-H-14'
        else:
            self.clip_model = None
            self.synchformer = None
            self.tokenizer = None

        if vae_ckpt is not None:
            with open(vae_config) as f:
                vae_config = json.load(f)
            self.vae = create_model_from_config(vae_config)
            print(f"Loading model checkpoint from {vae_ckpt}")
            # Load checkpoint
            copy_state_dict(self.vae, load_ckpt_state_dict(vae_ckpt, prefix='autoencoder.'))
        else:
            self.tod = None

    def compile(self):
        if self.clip_model is not None:
            self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
            self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)

    def train(self, mode: bool) -> None:
        return super().train(False)

    @torch.inference_mode()
    def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'
        # x: (B, T, C, H, W) with H/W: 224
        b, t, c, h, w = x.shape

        assert c == 3 and h == 224 and w == 224
        # x = self.clip_preprocess(x)
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        outputs = []
        if batch_size < 0:
            batch_size = b * t
        for i in range(0, b * t, batch_size):
            outputs.append(self.clip_model.get_image_features(x[i:i + batch_size]))
        x = torch.cat(outputs, dim=0)
        # x = self.clip_model.encode_image(x, normalize=True)
        x = rearrange(x, '(b t) d -> b t d', b=b)
        return x

    @torch.inference_mode()
    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.synchformer is not None, 'Synchformer is not loaded'
        # x: (B, T, C, H, W) with H/W: 224
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        # partition the video
        segment_size = 16
        step_size = 8
        num_segments = (t - segment_size) // step_size + 1
        segments = []
        for i in range(num_segments):
            segments.append(x[:, i * step_size:i * step_size + segment_size])
        x = torch.stack(segments, dim=1)  # (B, S, T, C, H, W)

        outputs = []
        if batch_size < 0:
            batch_size = b
        x = rearrange(x, 'b s t c h w -> (b s) 1 t c h w')
        for i in range(0, b * num_segments, batch_size):
            outputs.append(self.synchformer(x[i:i + batch_size]))
        x = torch.cat(outputs, dim=0)
        x = rearrange(x, '(b s) 1 t d -> b (s t) d', b=b)
        return x

    @torch.inference_mode()
    def encode_text(self, text: list[str]) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'
        # assert self.tokenizer is not None, 'Tokenizer is not loaded'
        # x: (B, L)
        tokens = self.clip_processor(text=text, truncation=True, max_length=77,
                                     padding="max_length", return_tensors="pt").to(self.device)
        return self.clip_model.get_text_features(**tokens)

    @torch.inference_mode()
    def encode_t5_text(self, text: list[str]) -> torch.Tensor:
        assert self.t5_model is not None, 'T5 model is not loaded'
        assert self.t5_tokenizer is not None, 'T5 Tokenizer is not loaded'
        # x: (B, L)
        inputs = self.t5_tokenizer(text,
                                   truncation=True,
                                   max_length=77,
                                   padding="max_length",
                                   return_tensors="pt").to(self.device)
        return self.t5_model(**inputs).last_hidden_state

    @torch.inference_mode()
    def encode_audio(self, x) -> torch.Tensor:
        x = self.vae.encode(x)
        return x

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
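For reference, encode_video_with_sync splits the frame sequence into 16-frame segments with a stride of 8 (50% overlap) before batching them through Synchformer. The standalone sketch below mirrors only that windowing arithmetic, so it runs without any checkpoints; the clip length is an arbitrary example.

# Standalone sketch of the windowing used by encode_video_with_sync; no models are loaded.
import torch

b, t = 1, 64                      # example clip length in frames
frames = torch.randn(b, t, 3, 224, 224)

segment_size, step_size = 16, 8
num_segments = (t - segment_size) // step_size + 1
segments = torch.stack(
    [frames[:, i * step_size:i * step_size + segment_size] for i in range(num_segments)],
    dim=1)                        # (B, S, 16, C, H, W), consecutive segments overlap by 8 frames
print(num_segments, segments.shape)   # 7 torch.Size([1, 7, 16, 3, 224, 224])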