lorocksUMD commited on
Commit
f2b5019
·
verified ·
1 Parent(s): e1b5568

Delete DenseAV

Browse files
DenseAV/.gitignore DELETED
@@ -1,5 +0,0 @@
1
- # Created by .ignore support plugin (hsz.mobi)
2
- results/attention/*
3
- results/features/*
4
-
5
- .env
 
 
 
 
 
 
DenseAV/LICENSE DELETED
@@ -1,22 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) Mark Hamilton. All rights reserved.
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a
6
- copy of this software and associated documentation files (the
7
- "Software"), to deal in the Software without restriction, including
8
- without limitation the rights to use, copy, modify, merge, publish,
9
- distribute, sublicense, and/or sell copies of the Software, and to
10
- permit persons to whom the Software is furnished to do so, subject to
11
- the following conditions:
12
-
13
- The above copyright notice and this permission notice shall be included
14
- in all copies or substantial portions of the Software.
15
-
16
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
17
- OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/README.md DELETED
@@ -1,172 +0,0 @@
1
- # Separating the "Chirp" from the "Chat": Self-supervised Visual Grounding of Sound and Language
2
- ### CVPR 2024
3
-
4
-
5
- [![Website](https://img.shields.io/badge/DenseAV-%F0%9F%8C%90Website-purple?style=flat)](https://aka.ms/denseav) [![arXiv](https://img.shields.io/badge/arXiv-2406.05629-b31b1b.svg)](https://arxiv.org/abs/2406.05629) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mhamilton723/DenseAV/blob/main/demo.ipynb)
6
-
7
- [![Huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DenseAV-orange)](https://huggingface.co/spaces/mhamilton723/DenseAV)
8
-
9
- [//]: # ([![Huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Paper%20Page-orange)](https://huggingface.co/papers/2403.10516))
10
- [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/separating-the-chirp-from-the-chat-self/speech-prompted-semantic-segmentation-on)](https://paperswithcode.com/sota/speech-prompted-semantic-segmentation-on?p=separating-the-chirp-from-the-chat-self)
11
- [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/separating-the-chirp-from-the-chat-self/sound-prompted-semantic-segmentation-on)](https://paperswithcode.com/sota/sound-prompted-semantic-segmentation-on?p=separating-the-chirp-from-the-chat-self)
12
-
13
-
14
- [Mark Hamilton](https://mhamilton.net/),
15
- [Andrew Zisserman](https://www.robots.ox.ac.uk/~az/),
16
- [John R. Hershey](https://research.google/people/john-hershey/),
17
- [William T. Freeman](https://billf.mit.edu/about/bio)
18
-
19
- ![DenseAV Overview Graphic](https://mhamilton.net/images/hero_fig_black.jpg)
20
-
21
- **TL;DR**:Our model, DenseAV, learns the meaning of words and the location of sounds (visual grounding) without supervision or text.
22
-
23
- https://github.com/mhamilton723/DenseAV/assets/6456637/ba908ab5-9618-42f9-8d7a-30ecb009091f
24
-
25
-
26
- ## Contents
27
- <!--ts-->
28
- * [Install](#install)
29
- * [Model Zoo](#model-zoo)
30
- * [Getting Datasets](#getting-atasets)
31
- * [Evaluate Models](#evaluate-models)
32
- * [Train a Model](#train-model)
33
- * [Local Gradio Demo](#local-gradio-demo)
34
- * [Coming Soon](coming-soon)
35
- * [Citation](#citation)
36
- * [Contact](#contact)
37
- <!--te-->
38
-
39
- ## Install
40
-
41
- To use DenseAV locally clone the repository:
42
-
43
- ```shell script
44
- git clone https://github.com/mhamilton723/DenseAV.git
45
- cd DenseAV
46
- pip install -e .
47
- ```
48
-
49
-
50
- ## Model Zoo
51
-
52
- To see examples of pretrained model usage please see our [Collab notebook](https://colab.research.google.com/github/mhamilton723/DenseAV/blob/main/demo.ipynb). We currently supply the following pretrained models:
53
-
54
- | Model Name | Checkpoint | Torch Hub Repository | Torch Hub Name |
55
- |-------------------------------|----------------------------------------------------------------------------------------------------------------------------------|----------------------|--------------------|
56
- | Sound | [Download](https://marhamilresearch4.blob.core.windows.net/denseav-public/hub/denseav_sound.ckpt) | mhamilton723/DenseAV | sound |
57
- | Language | [Download](https://marhamilresearch4.blob.core.windows.net/denseav-public/hub/denseav_language.ckpt) | mhamilton723/DenseAV | language |
58
- | Sound + Language (Two Headed) | [Download](https://marhamilresearch4.blob.core.windows.net/denseav-public/hub/denseav_2head.ckpt) | mhamilton723/DenseAV | sound_and_language |
59
-
60
- For example, to load the model trained on both sound and language:
61
-
62
- ```python
63
- model = torch.hub.load("mhamilton723/DenseAV", 'sound_and_language')
64
- ```
65
-
66
- ### Load from HuggingFace
67
-
68
- ```python
69
- from denseav.train import LitAVAligner
70
-
71
- model1 = LitAVAligner.from_pretrained("mhamilton723/DenseAV-sound")
72
- model2 = LitAVAligner.from_pretrained("mhamilton723/DenseAV-language")
73
- model3 = LitAVAligner.from_pretrained("mhamilton723/DenseAV-sound-language")
74
- ```
75
-
76
-
77
- ## Getting Datasets
78
-
79
- Our code assumes that all data lives in a common directory on your system, in these examples we use `/path/to/your/data`. Our code will often reference this directory as the `data_root`
80
-
81
- ### Speech and Sound Prompted ADE20K
82
-
83
- To download our new Speech and Sound prompted ADE20K Dataset:
84
-
85
- ```bash
86
- cd /path/to/your/data
87
- wget https://marhamilresearch4.blob.core.windows.net/denseav-public/datasets/ADE20KSoundPrompted.zip
88
- unzip ADE20KSoundPrompted.zip
89
- wget https://marhamilresearch4.blob.core.windows.net/denseav-public/datasets/ADE20KSpeechPrompted.zip
90
- unzip ADE20KSpeechPrompted.zip
91
- ```
92
-
93
- ### Places Audio
94
-
95
- First download the places audio dataset from its [original source](https://groups.csail.mit.edu/sls/downloads/placesaudio/downloads.cgi).
96
-
97
- To run the code the data will need to be processed to be of the form:
98
-
99
- ```
100
- [Instructions coming soon]
101
- ```
102
-
103
- ### Audioset
104
-
105
- Because of copyright issues we cannot make [Audioset](https://research.google.com/audioset/dataset/index.html) easily availible to download.
106
- First download this dataset through appropriate means. [This other project](https://github.com/ktonal/audioset-downloader) appears to make this simple.
107
-
108
- To run the code the data will need to be processed to be of the form:
109
-
110
- ```
111
- [Instructions coming soon]
112
- ```
113
-
114
-
115
- ## Evaluate Models
116
-
117
- To evaluate a trained model first clone the repository for
118
- [local development](#local-development). Then run
119
-
120
- ```shell
121
- cd featup
122
- python evaluate.py
123
- ```
124
-
125
- After evaluation, see the results in tensorboard's hparams tab.
126
-
127
- ```shell
128
- cd ../logs/evaluate
129
- tensorboard --logdir .
130
- ```
131
-
132
- Then visit [https://localhost:6006](https://localhost:6006) and click on hparams to browse results. We report "advanced" speech metrics and "basic" sound metrics in our paper.
133
-
134
-
135
- ## Train a Model
136
-
137
- ```shell
138
- cd denseav
139
- python train.py
140
- ```
141
-
142
- ## Local Gradio Demo
143
-
144
- To run our [HuggingFace Spaces hosted DenseAV demo](https://huggingface.co/spaces/mhamilton723/FeatUp) locally first install DenseAV for local development. Then run:
145
-
146
- ```shell
147
- python gradio_app.py
148
- ```
149
-
150
- Wait a few seconds for the demo to spin up, then navigate to [http://localhost:7860/](http://localhost:7860/) to view the demo.
151
-
152
-
153
- ## Coming Soon:
154
-
155
- - Bigger models!
156
-
157
- ## Citation
158
-
159
- ```
160
- @misc{hamilton2024separating,
161
- title={Separating the "Chirp" from the "Chat": Self-supervised Visual Grounding of Sound and Language},
162
- author={Mark Hamilton and Andrew Zisserman and John R. Hershey and William T. Freeman},
163
- year={2024},
164
- eprint={2406.05629},
165
- archivePrefix={arXiv},
166
- primaryClass={cs.CV}
167
- }
168
- ```
169
-
170
- ## Contact
171
-
172
- For feedback, questions, or press inquiries please contact [Mark Hamilton](mailto:[email protected])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/__init__.py DELETED
File without changes
DenseAV/demo.ipynb DELETED
The diff for this file is too large to render. See raw diff
 
DenseAV/denseav/__init__.py DELETED
File without changes
DenseAV/denseav/aggregators.py DELETED
@@ -1,517 +0,0 @@
1
- from abc import abstractmethod
2
-
3
- import math
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
- from tqdm import tqdm
8
-
9
- from denseav.constants import *
10
-
11
-
12
- @torch.jit.script
13
- def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int):
14
- mask = mask.to(x)
15
- return (x * mask).sum(dim, keepdim=True) / mask.sum(dim, keepdim=True).clamp_min(.001)
16
-
17
-
18
- @torch.jit.script
19
- def masked_max(x: torch.Tensor, mask: torch.Tensor, dim: int):
20
- mask = mask.to(torch.bool)
21
- eps = 1e7
22
- # eps = torch.finfo(x.dtype).max
23
- return (x - (~mask) * eps).max(dim, keepdim=True).values
24
-
25
-
26
- def masked_lse(x: torch.Tensor, mask: torch.Tensor, dim: int, temp):
27
- x = x.to(torch.float32)
28
- mask = mask.to(torch.float32)
29
- x_masked = (x - (1 - mask) * torch.finfo(x.dtype).max)
30
- return (torch.logsumexp(x_masked * temp, dim, keepdim=True) - torch.log(mask.sum(dim, keepdim=True))) / temp
31
-
32
-
33
- class BaseAggregator(torch.nn.Module):
34
-
35
- def __init__(self, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
36
- super().__init__()
37
-
38
- self.nonneg_sim = nonneg_sim
39
- self.mask_silence = mask_silence
40
- self.num_heads = num_heads
41
- self.head_agg = head_agg
42
- self.use_cls = use_cls
43
-
44
- @abstractmethod
45
- def _agg_sim(self, sim, mask):
46
- pass
47
-
48
- def prepare_sims(self, sim, mask, agg_sim, agg_heads):
49
- sim_size = sim.shape
50
- assert len(mask.shape) == 2
51
- assert len(sim_size) in {6, 7}, f"sim has wrong number of dimensions: {sim.shape}"
52
- pairwise = len(sim_size) == 6
53
-
54
- if self.mask_silence:
55
- mask = mask
56
- else:
57
- mask = torch.ones_like(mask)
58
-
59
- if self.nonneg_sim:
60
- sim = sim.clamp_min(0)
61
-
62
- if pairwise:
63
- head_dim = 1
64
- else:
65
- head_dim = 2
66
-
67
- if self.head_agg == "max_elementwise" and agg_heads:
68
- sim = sim.max(head_dim, keepdim=True).values
69
-
70
- if agg_sim:
71
- sim = self._agg_sim(sim, mask)
72
-
73
- if agg_heads:
74
- if self.head_agg == "sum" or self.head_agg == "max_elementwise":
75
- sim = sim.sum(head_dim)
76
- elif self.head_agg == "max":
77
- sim = sim.max(head_dim).values
78
- else:
79
- raise ValueError(f"Unknown head_agg: {self.head_agg}")
80
-
81
- return sim
82
-
83
- def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
84
- if agg_sim or agg_heads or raw:
85
- assert (agg_sim or agg_heads) != raw, "Cannot have raw on at the same time as agg_sim or agg_heads"
86
-
87
- audio_feats = preds[AUDIO_FEATS]
88
- audio_mask = preds[AUDIO_MASK]
89
- image_feats = preds[IMAGE_FEATS]
90
-
91
- b1, c2, f, t1 = audio_feats.shape
92
- b2, t2 = audio_mask.shape
93
- d, c1, h, w = image_feats.shape
94
- assert b1 == b2 and c1 == c2 and t1 == t2
95
- assert c1 % self.num_heads == 0
96
- new_c = c1 // self.num_heads
97
- audio_feats = audio_feats.reshape(b1, self.num_heads, new_c, f, t1)
98
- image_feats = image_feats.reshape(d, self.num_heads, new_c, h, w)
99
- raw_sims = torch.einsum(
100
- "akcft,vkchw->avkhwft",
101
- audio_feats.to(torch.float32),
102
- image_feats.to(torch.float32))
103
-
104
- if self.use_cls:
105
- audio_cls = preds[AUDIO_CLS].reshape(b1, self.num_heads, new_c)
106
- image_cls = preds[IMAGE_CLS].reshape(d, self.num_heads, new_c)
107
- cls_sims = torch.einsum(
108
- "akc,vkc->avk",
109
- audio_cls.to(torch.float32),
110
- image_cls.to(torch.float32))
111
- raw_sims += cls_sims.reshape(b1, d, self.num_heads, 1, 1, 1, 1)
112
-
113
- if raw:
114
- return raw_sims
115
- else:
116
- return self.prepare_sims(raw_sims, audio_mask, agg_sim, agg_heads)
117
-
118
- def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
119
- if agg_sim or agg_heads or raw:
120
- assert (agg_sim or agg_heads) != raw, "Cannot have raw on at the same time as agg_sim or agg_heads"
121
-
122
- audio_feats = preds[AUDIO_FEATS]
123
- audio_mask = preds[AUDIO_MASK]
124
- image_feats = preds[IMAGE_FEATS]
125
-
126
- a1, c1, f, t1 = audio_feats.shape
127
- a2, t2 = audio_mask.shape
128
-
129
- assert c1 % self.num_heads == 0
130
- new_c = c1 // self.num_heads
131
- audio_feats = audio_feats.reshape(a1, self.num_heads, new_c, f, t1)
132
-
133
- if len(image_feats.shape) == 5:
134
- print("Using similarity for video, should only be called during plotting")
135
- v, vt, c2, h, w = image_feats.shape
136
- image_feats = image_feats.reshape(v, vt, self.num_heads, new_c, h, w)
137
- raw_sims = torch.einsum(
138
- "bkcft,bskchw,bt->bskhwft",
139
- audio_feats.to(torch.float32),
140
- image_feats.to(torch.float32),
141
- audio_mask.to(torch.float32))
142
-
143
- if self.use_cls:
144
- audio_cls = preds[AUDIO_CLS].reshape(v, self.num_heads, new_c)
145
- image_cls = preds[IMAGE_CLS].reshape(v, vt, self.num_heads, new_c)
146
- cls_sims = torch.einsum(
147
- "bkc,bskc->bsk",
148
- audio_cls.to(torch.float32),
149
- image_cls.to(torch.float32))
150
- raw_sims += cls_sims.reshape(v, vt, self.num_heads, 1, 1, 1, 1)
151
-
152
-
153
- elif len(image_feats.shape) == 4:
154
- v, c2, h, w = image_feats.shape
155
- image_feats = image_feats.reshape(v, self.num_heads, new_c, h, w)
156
- raw_sims = torch.einsum(
157
- "bkcft,bkchw,bt->bkhwft",
158
- audio_feats.to(torch.float32),
159
- image_feats.to(torch.float32),
160
- audio_mask.to(torch.float32))
161
-
162
- if self.use_cls:
163
- audio_cls = preds[AUDIO_CLS].reshape(v, self.num_heads, new_c)
164
- image_cls = preds[IMAGE_CLS].reshape(v, self.num_heads, new_c)
165
- cls_sims = torch.einsum(
166
- "bkc,bkc->bk",
167
- audio_cls.to(torch.float32),
168
- image_cls.to(torch.float32))
169
- raw_sims += cls_sims.reshape(v, self.num_heads, 1, 1, 1, 1)
170
- else:
171
- raise ValueError(f"Improper image shape: {image_feats.shape}")
172
-
173
- assert a1 == a2 and c2 == c2 and t1 == t2
174
-
175
- if raw:
176
- return raw_sims
177
- else:
178
- return self.prepare_sims(raw_sims, audio_mask, agg_sim, agg_heads)
179
-
180
- def forward(self, preds, agg_heads):
181
- return self._get_full_sims(
182
- preds, raw=False, agg_sim=True, agg_heads=agg_heads)
183
-
184
- def forward_batched(self, preds, agg_heads, batch_size):
185
- new_preds = {k: v for k, v in preds.items()}
186
- big_image_feats = new_preds.pop(IMAGE_FEATS)
187
- if self.use_cls:
188
- big_image_cls = new_preds.pop(IMAGE_CLS)
189
-
190
- n = big_image_feats.shape[0]
191
- n_steps = math.ceil(n / batch_size)
192
- outputs = []
193
- for step in tqdm(range(n_steps), "Calculating Sim", leave=False):
194
- new_preds[IMAGE_FEATS] = big_image_feats[step * batch_size:(step + 1) * batch_size].cuda()
195
- if self.use_cls:
196
- new_preds[IMAGE_CLS] = big_image_cls[step * batch_size:(step + 1) * batch_size].cuda()
197
-
198
- sim = self.forward(new_preds, agg_heads=agg_heads)
199
- outputs.append(sim.cpu())
200
- return torch.cat(outputs, dim=1)
201
-
202
-
203
- class ImageThenAudioAggregator(BaseAggregator):
204
-
205
- def __init__(self, image_agg_type, audio_agg_type, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
206
- super().__init__(nonneg_sim, mask_silence, num_heads, head_agg, use_cls)
207
- if image_agg_type == "max":
208
- self.image_agg = lambda x, dim: x.max(dim=dim, keepdim=True).values
209
- elif image_agg_type == "avg":
210
- self.image_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)
211
- else:
212
- raise ValueError(f"Unknown image_agg_type {image_agg_type}")
213
-
214
- if audio_agg_type == "max":
215
- self.time_agg = masked_max
216
- elif audio_agg_type == "avg":
217
- self.time_agg = masked_mean
218
- else:
219
- raise ValueError(f"Unknown audio_agg_type {audio_agg_type}")
220
-
221
- self.freq_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)
222
-
223
- def _agg_sim(self, sim, mask):
224
- sim_shape = sim.shape
225
- new_mask_shape = [1] * len(sim_shape)
226
- new_mask_shape[0] = sim_shape[0]
227
- new_mask_shape[-1] = sim_shape[-1]
228
- mask = mask.reshape(new_mask_shape)
229
- sim = self.image_agg(sim, -3)
230
- sim = self.image_agg(sim, -4)
231
- sim = self.freq_agg(sim, -2)
232
- sim = self.time_agg(sim, mask, -1)
233
- return sim.squeeze(-1).squeeze(-1).squeeze(-1).squeeze(-1)
234
-
235
-
236
- class PairedAggregator(BaseAggregator):
237
-
238
- def __init__(self, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
239
- super().__init__(nonneg_sim, mask_silence, num_heads, head_agg, use_cls)
240
- self.image_agg_max = lambda x, dim: x.max(dim=dim, keepdim=True).values
241
- self.image_agg_mean = lambda x, dim: x.mean(dim=dim, keepdim=True)
242
-
243
- self.time_agg_max = masked_max
244
- self.time_agg_mean = masked_mean
245
-
246
- self.freq_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)
247
-
248
- def _agg_sim(self, sim, mask):
249
- sim_shape = sim.shape
250
- new_mask_shape = [1] * len(sim_shape)
251
- new_mask_shape[0] = sim_shape[0]
252
- new_mask_shape[-1] = sim_shape[-1]
253
- mask = mask.reshape(new_mask_shape)
254
-
255
- sim_1 = self.image_agg_max(sim, -3)
256
- sim_1 = self.image_agg_max(sim_1, -4)
257
- sim_1 = self.freq_agg(sim_1, -2)
258
- sim_1 = self.time_agg_mean(sim_1, mask, -1)
259
-
260
- sim_2 = self.freq_agg(sim, -2)
261
- sim_2 = self.time_agg_max(sim_2, mask, -1)
262
- sim_2 = self.image_agg_mean(sim_2, -3)
263
- sim_2 = self.image_agg_mean(sim_2, -4)
264
-
265
- sim = 1 / 2 * (sim_1 + sim_2)
266
-
267
- return sim.squeeze(-1).squeeze(-1).squeeze(-1).squeeze(-1)
268
-
269
-
270
-
271
- class CAVMAEAggregator(BaseAggregator):
272
-
273
- def __init__(self, *args, **kwargs):
274
- super().__init__(False, False, 1, "sum", False)
275
-
276
- def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
277
- if agg_sim:
278
- audio_feats = preds[AUDIO_FEATS]
279
- image_feats = preds[IMAGE_FEATS]
280
- pool_audio_feats = F.normalize(audio_feats.mean(dim=[-1, -2]), dim=1)
281
- pool_image_feats = F.normalize(image_feats.mean(dim=[-1, -2]), dim=1)
282
- sims = torch.einsum(
283
- "bc,dc->bd",
284
- pool_audio_feats.to(torch.float32),
285
- pool_image_feats.to(torch.float32))
286
- if agg_heads:
287
- return sims
288
- else:
289
- return sims.unsqueeze(-1)
290
-
291
- else:
292
- return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)
293
-
294
- def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
295
- if agg_sim:
296
- audio_feats = preds[AUDIO_FEATS]
297
- image_feats = preds[IMAGE_FEATS]
298
- pool_audio_feats = F.normalize(audio_feats.mean(dim=[-1, -2]), dim=1)
299
- pool_image_feats = F.normalize(image_feats.mean(dim=[-1, -2]), dim=1)
300
- sims = torch.einsum(
301
- "bc,bc->b",
302
- pool_audio_feats.to(torch.float32),
303
- pool_image_feats.to(torch.float32))
304
- if agg_heads:
305
- return sims
306
- else:
307
- return sims.unsqueeze(-1)
308
-
309
- else:
310
- return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)
311
-
312
-
313
- class ImageBindAggregator(BaseAggregator):
314
-
315
- def __init__(self, num_heads, *args, **kwargs):
316
- super().__init__(False, False, num_heads, "sum", False)
317
-
318
- def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
319
- if agg_sim:
320
- sims = torch.einsum(
321
- "bc,dc->bd",
322
- preds[AUDIO_CLS].to(torch.float32),
323
- preds[IMAGE_CLS].to(torch.float32))
324
- if agg_heads:
325
- return sims
326
- else:
327
- sims = sims.unsqueeze(-1)
328
- return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
329
-
330
-
331
- else:
332
- return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)
333
-
334
- def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
335
- if agg_sim:
336
- sims = torch.einsum(
337
- "bc,dc->b",
338
- preds[AUDIO_CLS].to(torch.float32),
339
- preds[IMAGE_CLS].to(torch.float32))
340
- if agg_heads:
341
- return sims
342
- else:
343
- sims = sims.unsqueeze(-1)
344
- return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
345
-
346
- else:
347
- return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)
348
-
349
- def forward_batched(self, preds, agg_heads, batch_size):
350
- return self.forward(preds, agg_heads)
351
-
352
-
353
- class SimPool(nn.Module):
354
- def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=None, use_beta=False):
355
- super().__init__()
356
- self.num_heads = num_heads
357
- head_dim = dim // num_heads
358
- self.scale = qk_scale or head_dim ** -0.5
359
-
360
- self.norm_patches = nn.LayerNorm(dim, eps=1e-6)
361
-
362
- self.wq = nn.Linear(dim, dim, bias=qkv_bias)
363
- self.wk = nn.Linear(dim, dim, bias=qkv_bias)
364
-
365
- if gamma is not None:
366
- self.gamma = torch.tensor([gamma])
367
- if use_beta:
368
- self.beta = nn.Parameter(torch.tensor([0.0]))
369
- self.eps = torch.tensor([1e-6])
370
-
371
- self.gamma = gamma
372
- self.use_beta = use_beta
373
-
374
- def prepare_input(self, x):
375
- if len(x.shape) == 3: # Transformer
376
- # Input tensor dimensions:
377
- # x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
378
- B, N, d = x.shape
379
- gap_cls = x.mean(-2) # (B, N, d) -> (B, d)
380
- gap_cls = gap_cls.unsqueeze(1) # (B, d) -> (B, 1, d)
381
- return gap_cls, x
382
- if len(x.shape) == 4: # CNN
383
- # Input tensor dimensions:
384
- # x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
385
- B, d, H, W = x.shape
386
- gap_cls = x.mean([-2, -1]) # (B, d, H, W) -> (B, d)
387
- x = x.reshape(B, d, H * W).permute(0, 2, 1) # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
388
- gap_cls = gap_cls.unsqueeze(1) # (B, d) -> (B, 1, d)
389
- return gap_cls, x
390
- else:
391
- raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")
392
-
393
- def forward(self, x):
394
- self.eps = self.eps.to(x.device)
395
- # Prepare input tensor and perform GAP as initialization
396
- gap_cls, x = self.prepare_input(x)
397
-
398
- # Prepare queries (q), keys (k), and values (v)
399
- q, k, v = gap_cls, self.norm_patches(x), self.norm_patches(x)
400
-
401
- # Extract dimensions after normalization
402
- Bq, Nq, dq = q.shape
403
- Bk, Nk, dk = k.shape
404
- Bv, Nv, dv = v.shape
405
-
406
- # Check dimension consistency across batches and channels
407
- assert Bq == Bk == Bv
408
- assert dq == dk == dv
409
-
410
- # Apply linear transformation for queries and keys then reshape
411
- qq = self.wq(q).reshape(Bq, Nq, self.num_heads, dq // self.num_heads).permute(0, 2, 1,
412
- 3) # (Bq, Nq, dq) -> (B, num_heads, Nq, dq/num_heads)
413
- kk = self.wk(k).reshape(Bk, Nk, self.num_heads, dk // self.num_heads).permute(0, 2, 1,
414
- 3) # (Bk, Nk, dk) -> (B, num_heads, Nk, dk/num_heads)
415
-
416
- vv = v.reshape(Bv, Nv, self.num_heads, dv // self.num_heads).permute(0, 2, 1,
417
- 3) # (Bv, Nv, dv) -> (B, num_heads, Nv, dv/num_heads)
418
-
419
- # Compute attention scores
420
- attn = (qq @ kk.transpose(-2, -1)) * self.scale
421
- # Apply softmax for normalization
422
- attn = attn.softmax(dim=-1)
423
-
424
- # If gamma scaling is used
425
- if self.gamma is not None:
426
- # Apply gamma scaling on values and compute the weighted sum using attention scores
427
- x = torch.pow(attn @ torch.pow((vv - vv.min() + self.eps), self.gamma),
428
- 1 / self.gamma) # (B, num_heads, Nv, dv/num_heads) -> (B, 1, 1, d)
429
- # If use_beta, add a learnable translation
430
- if self.use_beta:
431
- x = x + self.beta
432
- else:
433
- # Compute the weighted sum using attention scores
434
- x = (attn @ vv).transpose(1, 2).reshape(Bq, Nq, dq)
435
-
436
- return x.squeeze()
437
-
438
-
439
-
440
- class SimPoolAggregator(BaseAggregator):
441
-
442
- def __init__(self, num_heads, dim, *args, **kwargs):
443
- super().__init__(False, False, num_heads, "sum", False)
444
- self.pool = SimPool(dim, gamma=1.25)
445
-
446
- def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
447
- if agg_sim:
448
- device = self.pool.wq.weight.data.device
449
- pooled_audio = self.pool(preds[AUDIO_FEATS].to(torch.float32).to(device))
450
- pooled_image = self.pool(preds[IMAGE_FEATS].to(torch.float32).to(device))
451
-
452
- sims = torch.einsum(
453
- "bc,dc->bd",
454
- pooled_audio,
455
- pooled_image)
456
- if agg_heads:
457
- return sims
458
- else:
459
- sims = sims.unsqueeze(-1)
460
- return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
461
-
462
-
463
- else:
464
- return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)
465
-
466
- def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
467
- if agg_sim:
468
- device = self.pool.wq.weight.data.device
469
- pooled_audio = self.pool(preds[AUDIO_FEATS].to(torch.float32).to(device))
470
- pooled_image = self.pool(preds[IMAGE_FEATS].to(torch.float32).to(device))
471
-
472
- sims = torch.einsum(
473
- "bc,dc->b",
474
- pooled_audio,
475
- pooled_image)
476
- if agg_heads:
477
- return sims
478
- else:
479
- sims = sims.unsqueeze(-1)
480
- return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
481
-
482
- else:
483
- return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)
484
-
485
- def forward_batched(self, preds, agg_heads, batch_size):
486
- return self.forward(preds, agg_heads)
487
-
488
-
489
-
490
- def get_aggregator(sim_agg_type, nonneg_sim, mask_silence, num_heads, head_agg, use_cls, dim):
491
- shared_args = dict(
492
- nonneg_sim=nonneg_sim,
493
- mask_silence=mask_silence,
494
- num_heads=num_heads,
495
- head_agg=head_agg,
496
- use_cls=use_cls,
497
- )
498
-
499
- if sim_agg_type == "paired":
500
- agg1 = PairedAggregator(**shared_args)
501
- elif sim_agg_type == "misa":
502
- agg1 = ImageThenAudioAggregator("max", "avg", **shared_args)
503
- elif sim_agg_type == "mima":
504
- agg1 = ImageThenAudioAggregator("max", "max", **shared_args)
505
- elif sim_agg_type == "sisa":
506
- agg1 = ImageThenAudioAggregator("avg", "avg", **shared_args)
507
- elif sim_agg_type == "cavmae":
508
- agg1 = CAVMAEAggregator()
509
- elif sim_agg_type == "imagebind":
510
- agg1 = ImageBindAggregator(num_heads=shared_args["num_heads"])
511
- elif sim_agg_type == "simpool":
512
- agg1 = SimPoolAggregator(num_heads=shared_args["num_heads"], dim=dim)
513
- else:
514
- raise ValueError(f"Unknown loss_type {sim_agg_type}")
515
-
516
- return agg1
517
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/aligners.py DELETED
@@ -1,300 +0,0 @@
1
- from functools import partial
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- from torch.nn import ModuleList
6
-
7
- from denseav.featurizers.DINO import Block
8
-
9
-
10
- class ChannelNorm(torch.nn.Module):
11
-
12
- def __init__(self, dim, *args, **kwargs):
13
- super().__init__(*args, **kwargs)
14
- self.norm = torch.nn.LayerNorm(dim, eps=1e-4)
15
-
16
- def forward_spatial(self, x):
17
- return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
18
-
19
- def forward(self, x, cls):
20
- return self.forward_spatial(x), self.forward_cls(cls)
21
-
22
- def forward_cls(self, cls):
23
- if cls is not None:
24
- return self.norm(cls)
25
- else:
26
- return None
27
-
28
-
29
- def id_conv(dim, strength=.9):
30
- conv = torch.nn.Conv2d(dim, dim, 1, padding="same")
31
- start_w = conv.weight.data
32
- conv.weight.data = torch.nn.Parameter(
33
- torch.eye(dim, device=start_w.device).unsqueeze(-1).unsqueeze(-1) * strength + start_w * (1 - strength))
34
- conv.bias.data = torch.nn.Parameter(conv.bias.data * (1 - strength))
35
- return conv
36
-
37
-
38
- class LinearAligner(torch.nn.Module):
39
- def __init__(self, in_dim, out_dim, use_norm=True):
40
- super().__init__()
41
- self.in_dim = in_dim
42
- self.out_dim = out_dim
43
- if use_norm:
44
- self.norm = ChannelNorm(in_dim)
45
- else:
46
- self.norm = Identity2()
47
-
48
- if in_dim == out_dim:
49
- self.layer = id_conv(in_dim, 0)
50
- else:
51
- self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1)
52
-
53
- self.cls_layer = torch.nn.Linear(in_dim, out_dim)
54
-
55
- def forward(self, spatial, cls):
56
- norm_spatial, norm_cls = self.norm(spatial, cls)
57
-
58
- if cls is not None:
59
- aligned_cls = self.cls_layer(cls)
60
- else:
61
- aligned_cls = None
62
-
63
- return self.layer(norm_spatial), aligned_cls
64
-
65
- class IdLinearAligner(torch.nn.Module):
66
- def __init__(self, in_dim, out_dim):
67
- super().__init__()
68
- self.in_dim = in_dim
69
- self.out_dim = out_dim
70
- assert self.out_dim == self.in_dim
71
- self.layer = id_conv(in_dim, 1.0)
72
- def forward(self, spatial, cls):
73
- return self.layer(spatial), cls
74
-
75
-
76
- class FrequencyAvg(torch.nn.Module):
77
- def __init__(self):
78
- super().__init__()
79
-
80
- def forward(self, spatial, cls):
81
- return spatial.mean(2, keepdim=True), cls
82
-
83
-
84
- class LearnedTimePool(torch.nn.Module):
85
- def __init__(self, dim, width, maxpool):
86
- super().__init__()
87
- self.dim = dim
88
- self.width = width
89
- self.norm = ChannelNorm(dim)
90
- if maxpool:
91
- self.layer = torch.nn.Sequential(
92
- torch.nn.Conv2d(dim, dim, kernel_size=width, stride=1, padding="same"),
93
- torch.nn.MaxPool2d(kernel_size=(1, width), stride=(1, width))
94
- )
95
- else:
96
- self.layer = torch.nn.Conv2d(dim, dim, kernel_size=(1, width), stride=(1, width))
97
-
98
- def forward(self, spatial, cls):
99
- norm_spatial, norm_cls = self.norm(spatial, cls)
100
- return self.layer(norm_spatial), norm_cls
101
-
102
-
103
- class LearnedTimePool2(torch.nn.Module):
104
- def __init__(self, in_dim, out_dim, width, maxpool, use_cls_layer):
105
- super().__init__()
106
- self.in_dim = in_dim
107
- self.out_dim = out_dim
108
- self.width = width
109
-
110
- if maxpool:
111
- self.layer = torch.nn.Sequential(
112
- torch.nn.Conv2d(in_dim, out_dim, kernel_size=width, stride=1, padding="same"),
113
- torch.nn.MaxPool2d(kernel_size=(1, width), stride=(1, width))
114
- )
115
- else:
116
- self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=(1, width), stride=(1, width))
117
-
118
- self.use_cls_layer = use_cls_layer
119
- if use_cls_layer:
120
- self.cls_layer = torch.nn.Linear(in_dim, out_dim)
121
-
122
- def forward(self, spatial, cls):
123
-
124
- if cls is not None:
125
- if self.use_cls_layer:
126
- aligned_cls = self.cls_layer(cls)
127
- else:
128
- aligned_cls = cls
129
- else:
130
- aligned_cls = None
131
-
132
- return self.layer(spatial), aligned_cls
133
-
134
-
135
- class Sequential2(torch.nn.Module):
136
-
137
- def __init__(self, *modules):
138
- super().__init__()
139
- self.mod_list = ModuleList(modules)
140
-
141
- def forward(self, x, y):
142
- results = (x, y)
143
- for m in self.mod_list:
144
- results = m(*results)
145
- return results
146
-
147
-
148
- class ProgressiveGrowing(torch.nn.Module):
149
-
150
- def __init__(self, stages, phase_lengths):
151
- super().__init__()
152
- self.stages = torch.nn.ModuleList(stages)
153
- self.phase_lengths = torch.tensor(phase_lengths)
154
- assert len(self.phase_lengths) + 1 == len(self.stages)
155
- self.phase_boundaries = self.phase_lengths.cumsum(0)
156
- self.register_buffer('phase', torch.tensor([1]))
157
-
158
- def maybe_change_phase(self, global_step):
159
- needed_phase = (global_step >= self.phase_boundaries).to(torch.int64).sum().item() + 1
160
- if needed_phase != self.phase.item():
161
- print(f"Changing aligner phase to {needed_phase}")
162
- self.phase.copy_(torch.tensor([needed_phase]).to(self.phase.device))
163
- return True
164
- else:
165
- return False
166
-
167
- def parameters(self, recurse: bool = True):
168
- phase = self.phase.item()
169
- used_stages = self.stages[:phase]
170
- print(f"Progressive Growing at stage {phase}")
171
- all_params = []
172
- for stage in used_stages:
173
- all_params.extend(stage.parameters(recurse))
174
- return iter(all_params)
175
-
176
- def forward(self, spatial, cls):
177
- pipeline = Sequential2(*self.stages[:self.phase.item()])
178
- return pipeline(spatial, cls)
179
-
180
-
181
- class Identity2(torch.nn.Module):
182
-
183
- def __init__(self):
184
- super().__init__()
185
-
186
- def forward(self, x, y):
187
- return x, y
188
-
189
-
190
- class SelfAttentionAligner(torch.nn.Module):
191
-
192
- def __init__(self, dim):
193
- super().__init__()
194
- self.dim = dim
195
-
196
- self.num_heads = 6
197
- if dim % self.num_heads != 0:
198
- self.padding = self.num_heads - (dim % self.num_heads)
199
- else:
200
- self.padding = 0
201
-
202
- self.block = Block(
203
- dim + self.padding,
204
- num_heads=self.num_heads,
205
- mlp_ratio=4,
206
- qkv_bias=True,
207
- qk_scale=None,
208
- drop=0.0,
209
- attn_drop=0.0,
210
- drop_path=0.0,
211
- norm_layer=partial(torch.nn.LayerNorm, eps=1e-4))
212
-
213
- def forward(self, spatial, cls):
214
- padded_feats = F.pad(spatial, [0, 0, 0, 0, self.padding, 0])
215
-
216
- B, C, H, W = padded_feats.shape
217
- proj_feats = padded_feats.reshape(B, C, H * W).permute(0, 2, 1)
218
-
219
- if cls is not None:
220
- assert len(cls.shape) == 2
221
- padded_cls = F.pad(cls, [self.padding, 0])
222
- proj_feats = torch.cat([padded_cls.unsqueeze(1), proj_feats], dim=1)
223
-
224
- aligned_feat, attn, qkv = self.block(proj_feats, return_qkv=True)
225
-
226
- if cls is not None:
227
- aligned_cls = aligned_feat[:, 0, :]
228
- aligned_spatial = aligned_feat[:, 1:, :]
229
- else:
230
- aligned_cls = None
231
- aligned_spatial = aligned_feat
232
-
233
- aligned_spatial = aligned_spatial.reshape(B, H, W, self.dim + self.padding).permute(0, 3, 1, 2)
234
-
235
- aligned_spatial = aligned_spatial[:, self.padding:, :, :]
236
- if aligned_cls is not None:
237
- aligned_cls = aligned_cls[:, self.padding:]
238
-
239
- return aligned_spatial, aligned_cls
240
-
241
-
242
- def get_aligner(aligner_type, in_dim, out_dim, **kwargs):
243
- if aligner_type is None:
244
- return Identity2()
245
-
246
- if "prog" in aligner_type:
247
- phase_length = kwargs["phase_length"]
248
-
249
- if aligner_type == "image_linear":
250
- return LinearAligner(in_dim, out_dim)
251
- elif aligner_type == "image_idlinear":
252
- return IdLinearAligner(in_dim, out_dim)
253
- elif aligner_type == "image_linear_no_norm":
254
- return LinearAligner(in_dim, out_dim, use_norm=False)
255
- elif aligner_type == "image_id":
256
- return Identity2()
257
- elif aligner_type == "image_norm":
258
- return ChannelNorm(in_dim)
259
- elif aligner_type == "audio_linear":
260
- return Sequential2(
261
- LinearAligner(in_dim, out_dim),
262
- FrequencyAvg())
263
- elif aligner_type == "audio_sa":
264
- return Sequential2(
265
- LinearAligner(in_dim, out_dim),
266
- FrequencyAvg(),
267
- SelfAttentionAligner(out_dim)
268
- )
269
- elif aligner_type == "audio_sa_sa":
270
- return Sequential2(
271
- FrequencyAvg(),
272
- LinearAligner(in_dim, out_dim),
273
- SelfAttentionAligner(out_dim),
274
- SelfAttentionAligner(out_dim)
275
- )
276
- elif aligner_type == "audio_3_3_pool":
277
- return Sequential2(
278
- LinearAligner(in_dim, out_dim),
279
- FrequencyAvg(),
280
- LearnedTimePool(out_dim, 3, False),
281
- LearnedTimePool(out_dim, 3, False),
282
- )
283
- elif aligner_type == "audio_sa_3_3_pool":
284
- return Sequential2(
285
- LinearAligner(in_dim, out_dim),
286
- FrequencyAvg(),
287
- LearnedTimePool(out_dim, 3, False),
288
- LearnedTimePool(out_dim, 3, False),
289
- SelfAttentionAligner(out_dim)
290
- )
291
- elif aligner_type == "audio_sa_3_3_pool_2":
292
- return Sequential2(
293
- FrequencyAvg(),
294
- ChannelNorm(in_dim),
295
- LearnedTimePool2(in_dim, out_dim, 3, False, True),
296
- LearnedTimePool2(out_dim, out_dim, 3, False, False),
297
- SelfAttentionAligner(out_dim)
298
- )
299
- else:
300
- raise ValueError(f"Unknown aligner type {aligner_type}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/configs/av_align.yaml DELETED
@@ -1,125 +0,0 @@
1
- # Model args
2
-
3
- code_dim: 384
4
- image_model_type: "dino8"
5
- image_model_token_type: "token"
6
- image_aligner_type: "image_linear"
7
- image_pool_width: 2
8
-
9
- audio_model_type: "hubert"
10
- audio_aligner_type: "audio_sa_3_3_pool_2"
11
- audio_pool_width: 1
12
-
13
- learn_audio_cls: True
14
-
15
- #code_dim: 1024
16
- #image_model_type: "imagebind"
17
- #image_model_token_type: "token"
18
- #image_aligner_type: "image_linear"
19
- #image_pool_width: 1
20
- #
21
- #audio_model_type: "imagebind"
22
- #audio_aligner_type: "audio_sa"
23
- #audio_pool_width: 1
24
- #
25
- #learn_audio_cls: False
26
-
27
- audio_lora: False
28
- audio_lora_rank: 8
29
- image_lora: True
30
- image_lora_rank: 8
31
-
32
-
33
- spatial_dropout: 0.0
34
- channel_dropout: 0.0
35
-
36
- quad_mixup: 0.1
37
- bg_mixup: 0.0
38
- patch_mixup: 0.0
39
- mixup_weight: 0.1
40
-
41
- sim_agg_type: "misa"
42
- sim_agg_heads: 1
43
- sim_use_cls: False
44
-
45
- cal_init: 1.0
46
- cal_balance_weight: 0.1
47
- nonneg_sim: False
48
- nonneg_pressure: 0.01
49
- silence_l1: 0.01
50
- silence_l2: 0.0
51
- tv_weight: 0.01
52
- specialization_weight: 0.05
53
- head_agg: "max_elementwise"
54
- disentangle_weight: 0.0
55
-
56
- norm_vectors: False
57
-
58
- neg_audio: true
59
- neg_audio_weight: 0.01
60
-
61
-
62
- pretrain_steps: 3000
63
- pretrain_lr: .5e-4
64
-
65
- # Loss args
66
- lr: .5e-4
67
- lr_warmup: 1000
68
-
69
- #lr_warmup: 100
70
-
71
- lr_schedule: ~
72
- lr_cycle_length: 50000
73
-
74
- optimizer: "adam"
75
- gradient_clipping: 10.0
76
- adaptive_clipping: True
77
- gather_tensors: True
78
- loss_type: "nce"
79
- loss_leak: 0.0
80
- loss_margin: 0.0
81
- mask_silence: true
82
- extra_audio_masking: true
83
- max_steps: 1000001
84
-
85
- finetune_image_model: False
86
- finetune_audio_model: True
87
-
88
- # Checkpointing args
89
- load_strict: true
90
- starting_weights: ~
91
- auto_resume: false
92
- grouping_name: "foo"
93
- resume_prefix: "imagebind_exp2"
94
-
95
- # Data Args
96
- #dataset_name: "sample-audio"
97
- dataset_name: "places-audio"
98
- #dataset_name: "mixed"
99
- #dataset_name: "audio-set-full"
100
- use_extra_val_sets: true
101
- batch_size: 10
102
- load_size: 224
103
- image_aug: true
104
- audio_aug: false
105
-
106
- audio_level: false
107
-
108
- memory_buffer_size: 0
109
-
110
- val_check_interval: 10000 #0
111
- use_cached_embs: false
112
- num_workers: 12
113
- num_gpus: 4
114
- num_sanity_val_steps: 0 #-1
115
- seed: 0
116
-
117
- # Env args
118
- output_root: '../'
119
- pytorch_data_dir: '/pytorch-data/'
120
- submitting_to_aml: false
121
-
122
- hydra:
123
- run:
124
- dir: "."
125
- output_subdir: ~
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/constants.py DELETED
@@ -1,12 +0,0 @@
1
-
2
- IMAGE_INPUT = "frames"
3
- IMAGE_FEATS = "image_feats"
4
- IMAGE_CLS = "image_cls"
5
- IMAGE_MASK = "image_masks"
6
-
7
- AUDIO_FEATS = "audio_feats"
8
- AUDIO_CLS = "audio_cls"
9
- AUDIO_MASK = "audio_mask"
10
- AUDIO_POS_MASK = "audio_pos_mask"
11
-
12
- DATA_SOURCE = "source"
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/data/AVDatasets.py DELETED
@@ -1,1249 +0,0 @@
1
- import glob
2
- import os
3
- from abc import ABC, abstractmethod
4
- from glob import glob
5
- from os.path import join
6
- from pathlib import Path
7
- from typing import List, Set
8
-
9
- import audioread
10
- import numpy as np
11
- import pandas as pd
12
- import pytorch_lightning as pl
13
- import torch
14
- import torch.nn.functional as F
15
- import torchaudio
16
- import torchvision.transforms as T
17
- from PIL import Image
18
- from torch.utils.data import Dataset, DataLoader, default_collate, Subset, ConcatDataset
19
- from tqdm import tqdm
20
-
21
- from denseav.constants import AUDIO_MASK, AUDIO_POS_MASK, IMAGE_MASK, IMAGE_INPUT
22
- from denseav.data.make_tarballs import untar_all
23
- from denseav.shared import norm, prep_waveform
24
-
25
-
26
- def sample_choice(choices, probs):
27
- # Check that probabilities sum to 1 and are non-negative
28
- assert sum(probs) == 1, "Probabilities must sum to 1"
29
- assert all(p >= 0 for p in probs), "Probabilities cannot be negative"
30
-
31
- # Convert probs to a tensor
32
- probs_tensor = torch.tensor(probs)
33
-
34
- # Sample a choice according to the probabilities
35
- index = torch.multinomial(probs_tensor, 1).item()
36
-
37
- # Return the sampled choice
38
- return choices[index]
39
-
40
-
41
- def grid_frames(frames):
42
- top_row = torch.cat([frames[0], frames[1]], dim=2)
43
- bottom_row = torch.cat([frames[2], frames[3]], dim=2)
44
- return torch.cat([top_row, bottom_row], dim=3)
45
-
46
-
47
- def create_mixed_image(pos_frame, neg_frame, patch_size):
48
- # Step 1: Check that patch_size evenly divides the image dimensions
49
- b, c, h, w = pos_frame.shape
50
- assert h % patch_size == 0 and w % patch_size == 0, "Patch size must evenly divide image dimensions"
51
-
52
- # Step 2: Create a random binary mask with the same number of patches as the image
53
- mask = torch.randint(0, 2, (b, 1, h // patch_size, w // patch_size))
54
-
55
- # Step 3: Create a new image using patches from pos_frame and neg_frame according to the mask
56
- # Upscale the mask to the size of the image
57
- mask_upscaled = F.interpolate(mask.to(torch.float32), scale_factor=patch_size)
58
-
59
- # Use the mask to create a mixed frame
60
- mixed_frame = mask_upscaled * pos_frame + (1 - mask_upscaled) * neg_frame
61
-
62
- return mixed_frame, mask_upscaled
63
-
64
-
65
- class AVDataset(ABC, Dataset):
66
-
67
- @abstractmethod
68
- def _dataset_folder(self) -> str:
69
- pass
70
-
71
- @abstractmethod
72
- def _load_info(self, split) -> pd.DataFrame:
73
- """
74
- This function should return a dataframe with at least a column "id"
75
- @return:
76
- """
77
- pass
78
-
79
- @abstractmethod
80
- def _missing_threshold(self) -> float:
81
- pass
82
-
83
- @abstractmethod
84
- def default_target_length(self) -> int:
85
- pass
86
-
87
- def target_length(self):
88
- if self.override_target_length is not None:
89
- return self.override_target_length
90
- else:
91
- return self.default_target_length()
92
-
93
- def _frame_root(self) -> str:
94
- return join(self.root, "frames", self.split)
95
-
96
- def _video_root(self) -> str:
97
- return join(self.root, "videos", self.split)
98
-
99
- def _audio_root(self) -> str:
100
- return join(self.root, "audio", self.split)
101
-
102
- def _semseg_root(self) -> str:
103
- return join(self.root, "annotations", self.split)
104
-
105
- def _embed_root(self) -> str:
106
- return join(self.root, "embedding", self.audio_embed_model, self.split)
107
-
108
- def _label_root(self) -> str:
109
- return join(self.root, "pseudo-labels")
110
-
111
- def _hn_root(self) -> str:
112
- return join(self.root, "hard_negatives")
113
-
114
- def _all_video_files(self) -> Set[str]:
115
- return set(str(p) for p in Path(join(self._video_root())).rglob('*'))
116
-
117
- def _all_frame_files(self) -> Set[str]:
118
- return set(str(p) for p in Path(join(self._frame_root())).rglob('*'))
119
-
120
- def _all_audio_files(self) -> Set[str]:
121
- return set(str(p) for p in Path(join(self._audio_root())).rglob('*'))
122
-
123
- def _all_embed_files(self) -> Set[str]:
124
- return set(str(p) for p in Path(join(self._embed_root())).rglob('*'))
125
-
126
- def _get_frame_files(self, row) -> List[str]:
127
- return [self._frame_root() + "/" + row["id"] + f"_{i}.jpg" for i in range(self._expected_num_frames())]
128
-
129
- def _get_semseg_file(self, row) -> str:
130
- raise NotImplementedError("Class has not implemented _get_semseg_files")
131
-
132
- def _get_audio_file(self, row) -> str:
133
- return self._audio_root() + "/" + row["id"] + ".mp3"
134
-
135
- def _get_video_file(self, row) -> str:
136
- return self._video_root() + "/" + row["id"] + ".mp4"
137
-
138
- def _get_embed_file(self, row) -> str:
139
- return self._embed_root() + "/" + row["id"] + ".npz"
140
-
141
- def _add_files_to_metadata(self, df) -> pd.DataFrame:
142
- tqdm.pandas()
143
-
144
- if self.use_audio_embed:
145
- df["embed_file"] = df.progress_apply(self._get_embed_file, axis=1)
146
-
147
- if self.use_audio or self.use_spec:
148
- df["audio_file"] = df.progress_apply(self._get_audio_file, axis=1)
149
-
150
- if self.use_frames:
151
- df["frame_files"] = df.progress_apply(self._get_frame_files, axis=1)
152
-
153
- if self.use_semseg:
154
- df["semseg_file"] = df.progress_apply(self._get_semseg_file, axis=1)
155
-
156
- df = self._filter_valid_metadata(df)
157
-
158
- if self.use_hn:
159
- loaded = np.load(join(self._hn_root(), "original", f"{self.split}_hard_negatives.npz"))
160
- df["hn0"] = [t for t in torch.tensor(loaded["indices_0"])]
161
- df["hn1"] = [t for t in torch.tensor(loaded["indices_1"])]
162
-
163
- return df
164
-
165
- def _split_name(self, split):
166
- return split
167
-
168
- def _filter_valid_metadata(self, df: pd.DataFrame) -> pd.DataFrame:
169
-
170
- print("MY_DIR ", list(glob(join(self.root, "*"))))
171
- if self.use_audio_embed:
172
- missing_embed_files = set(df['embed_file']) - self.all_embed_files
173
- valid_audio = ~df['embed_file'].isin(missing_embed_files)
174
- print("ALL EMBED ", len(self.all_embed_files))
175
- elif self.use_audio or self.use_spec:
176
- missing_audio_files = set(df['audio_file']) - self.all_audio_files
177
- valid_audio = ~df['audio_file'].isin(missing_audio_files)
178
- print("ALL AUDIO ", len(self.all_audio_files))
179
-
180
- if self.use_frames:
181
- missing_frame_files = set(
182
- item for sublist in df['frame_files'].tolist() for item in sublist) - self.all_frame_files
183
- valid_frames = df['frame_files'].apply(lambda x: not any(file in missing_frame_files for file in x))
184
- print("ALL FRAMES ", len(self.all_frame_files))
185
- df["is_valid"] = valid_audio & valid_frames
186
- else:
187
- df["is_valid"] = valid_audio
188
-
189
- percent_missing = (1 - (df["is_valid"].sum() / len(df)))
190
-
191
- assert percent_missing <= self._missing_threshold(), \
192
- f"Too many missing files: %{round(percent_missing * 100.0, 2)}"
193
- assert len(df) > 0, "No files found"
194
- return df[df["is_valid"]]
195
-
196
- def __init__(
197
- self,
198
- root: str,
199
- split: str = "train",
200
- use_frames=False,
201
- frame_transform=None,
202
- use_audio=False,
203
- use_spec=False,
204
- use_audio_embed=False,
205
- use_hn=False,
206
- use_caption=False,
207
- use_semseg=False,
208
- neg_audio=False,
209
- use_davenet_spec=False,
210
- use_fnac_spec=False,
211
- n_label_frames=196,
212
- label_transform=None,
213
- audio_embed_model="hubert",
214
- n_frames=1,
215
- audio_transform=None,
216
- audio_aug=False,
217
- spec_transform=None,
218
- spec_mel_bins=128,
219
- spec_mean=-6.6268077,
220
- spec_std=5.358466,
221
- sample_rate=16000,
222
- override_target_length=None,
223
- use_tags=False,
224
- extra_audio_masking=False,
225
- audio_level=False,
226
- quad_mixup=0.0,
227
- bg_mixup=0.0,
228
- patch_mixup=0.0,
229
- patch_size=8,
230
- ):
231
- super(AVDataset).__init__()
232
- self.pytorch_data_dir = root
233
- self.split = self._split_name(split)
234
- self.root = join(root, self._dataset_folder())
235
- self.use_frames = use_frames
236
- self.frame_transform = frame_transform
237
- self.use_audio = use_audio
238
- self.use_spec = use_spec
239
- self.use_audio_embed = use_audio_embed
240
- self.use_davenet_spec = use_davenet_spec
241
- self.use_fnac_spec = use_fnac_spec
242
- self.use_hn = use_hn
243
- self.use_caption = use_caption
244
- self.label_transform = label_transform
245
- self.audio_embed_model = audio_embed_model
246
- self.audio_aug = audio_aug
247
- self.n_frames = n_frames
248
- self.audio_transform = audio_transform
249
- self.spec_transform = spec_transform
250
- self.spec_mel_bins = spec_mel_bins
251
- self.spec_mean = spec_mean
252
- self.spec_std = spec_std
253
- self.use_semseg = use_semseg
254
- self.override_target_length = override_target_length
255
- self.use_tags = use_tags
256
- self.extra_audio_masking = extra_audio_masking
257
- self.neg_audio = neg_audio
258
- self.audio_level = audio_level
259
-
260
- self.quad_mixup = quad_mixup
261
- self.bg_mixup = bg_mixup
262
- self.patch_mixup = patch_mixup
263
- self.patch_size = patch_size
264
-
265
- self.sample_rate = sample_rate
266
- self.n_label_frames = n_label_frames
267
-
268
- if self.use_audio_embed:
269
- self.all_embed_files = self._all_embed_files()
270
-
271
- if self.use_audio or self.use_spec:
272
- self.all_audio_files = self._all_audio_files()
273
-
274
- if self.use_frames:
275
- self.all_frame_files = self._all_frame_files()
276
-
277
- self.metadata = self._add_files_to_metadata(self._load_info(self.split))
278
-
279
- assert len(self.metadata) > 0
280
-
281
- def __len__(self):
282
- return len(self.metadata)
283
-
284
- @abstractmethod
285
- def _expected_num_frames(self) -> int:
286
- pass
287
-
288
- def get_audio_mask(self, real_length, padded_length, target_size):
289
- if not isinstance(real_length, torch.Tensor):
290
- real_length = torch.tensor(real_length)
291
- padded_length = torch.tensor(padded_length)
292
-
293
- n_frames = ((real_length / padded_length) * target_size).to(torch.int64)
294
- oh = F.one_hot(n_frames, num_classes=target_size + 1)
295
- if len(oh.shape) == 1:
296
- oh = oh.unsqueeze(0)
297
- return (1 - torch.cumsum(oh, dim=1))[:, :-1].to(torch.bool)
298
-
299
- def _base_get_item(self, item):
300
- id = self.metadata["id"].iloc[item]
301
- data_dict = {"metadata": {"id": id, "index": item}}
302
-
303
- if self.use_tags and "tags" in self.metadata:
304
- tags = torch.tensor(self.metadata["tags"].iloc[item])
305
- tag_oh = torch.zeros(self.num_tags, dtype=torch.float32)
306
- tag_oh[tags] += 1
307
- data_dict["tags"] = tag_oh
308
-
309
- if self.use_audio or self.use_spec:
310
- audio_file = self.metadata["audio_file"].iloc[item]
311
- data_dict["metadata"]["audio_file"] = audio_file
312
- loaded_waveform, obs_sr = torchaudio.load(audio_file)
313
- loaded_waveform = loaded_waveform[0]
314
-
315
- if self.neg_audio:
316
- neg_audio_file = self.metadata["audio_file"].iloc[torch.randint(0, len(self), size=(1,)).item()]
317
- data_dict["metadata"]["neg_audio_file"] = neg_audio_file
318
- neg_waveform, neg_obs_sr = torchaudio.load(neg_audio_file)
319
- neg_waveform = neg_waveform[0]
320
- else:
321
- neg_waveform, neg_obs_sr = None, None
322
-
323
- (waveform,
324
- spectrogram,
325
- audio_length,
326
- total_length,
327
- original_length,
328
- mask,
329
- pos_mask) = prep_waveform(
330
- loaded_waveform,
331
- obs_sr,
332
- self.target_length(),
333
- self.spec_mel_bins,
334
- self.spec_mean,
335
- self.spec_std,
336
- self.sample_rate,
337
- self.use_spec,
338
- False,
339
- self.extra_audio_masking,
340
- neg_waveform,
341
- neg_obs_sr,
342
- self.audio_level,
343
- self.audio_aug
344
- )
345
-
346
- if self.spec_transform is not None and spectrogram is not None:
347
- spectrogram = self.spec_transform(spectrogram)
348
-
349
- if self.audio_transform is not None:
350
- waveform = self.audio_transform(waveform)
351
-
352
- data_dict["audio"] = waveform
353
- data_dict[AUDIO_MASK] = mask
354
- data_dict[AUDIO_POS_MASK] = pos_mask
355
- data_dict["audio_length"] = audio_length
356
- data_dict["original_length"] = original_length
357
- data_dict["total_length"] = total_length
358
- if spectrogram is not None:
359
- data_dict["spec"] = spectrogram
360
-
361
- if mask.mean() < .04:
362
- return None
363
-
364
- if self.use_davenet_spec:
365
- from data.DavenetUtilities import davenet_load_audio
366
- audio_file = self.metadata["audio_file"].iloc[item]
367
- spec, n_frames = davenet_load_audio(audio_file)
368
- data_dict["davenet_spec"] = spec
369
-
370
- if self.use_fnac_spec:
371
- from featurizers.FNACAVL import load_spectrogram as fnac_load_spectrogram
372
- audio_file = self.metadata["audio_file"].iloc[item]
373
- data_dict["fnac_spec"] = fnac_load_spectrogram(audio_file, 3)
374
-
375
- if self.use_audio_embed:
376
- loaded = np.load(self.metadata["embed_file"].iloc[item])
377
- data_dict["audio_emb"] = loaded["feat"]
378
- data_dict["audio_length"] = loaded["audio_length"]
379
- data_dict["total_length"] = loaded["total_length"]
380
- data_dict["original_length"] = loaded["original_length"]
381
- data_dict[AUDIO_MASK] = self.get_audio_mask(
382
- data_dict["audio_length"],
383
- data_dict["total_length"],
384
- data_dict["audio_emb"].shape[-1]) \
385
- .squeeze().to(torch.float32)
386
- data_dict[AUDIO_POS_MASK] = data_dict[AUDIO_MASK].to(torch.float32)
387
-
388
- if self.use_frames:
389
-
390
- def get_frames(item):
391
- file_group = self.metadata["frame_files"].iloc[item]
392
- if self.n_frames is not None:
393
- selected_frames = torch.randperm(len(file_group))[:self.n_frames]
394
- file_group = [file_group[i] for i in selected_frames]
395
- data_dict["metadata"]["frame_files"] = file_group
396
- images = [Image.open(file).convert("RGB") for file in file_group]
397
-
398
- if self.frame_transform is not None:
399
- images = torch.cat([self.frame_transform(img).unsqueeze(0) for img in images], dim=0)
400
-
401
- return images, file_group
402
-
403
- no_mixup = 1.0 - (self.bg_mixup + self.quad_mixup + self.patch_mixup)
404
-
405
- mixup_type = sample_choice(
406
- ["quad", "bg", "patch", None],
407
- [self.quad_mixup, self.bg_mixup, self.patch_mixup, no_mixup]
408
- )
409
-
410
- if mixup_type == "quad":
411
- indices = [item] + torch.randint(0, len(self), size=(3,)).numpy().tolist()
412
- frames_and_files = [get_frames(i) for i in indices]
413
- file_group = frames_and_files[0][1]
414
- perm = torch.randperm(4)
415
- all_frames = [F.interpolate(frames_and_files[i][0], scale_factor=0.5, mode="bilinear") for i in
416
- perm]
417
- b, c, h, w = all_frames[0].shape
418
- indices = [indices[p] for p in perm]
419
- masks = [(torch.ones(b, 1, h, w) if index == item else torch.zeros(b, 1, h, w)) for index in
420
- indices]
421
-
422
- data_dict[IMAGE_INPUT] = grid_frames(all_frames)
423
- data_dict[IMAGE_MASK] = grid_frames(masks)
424
- elif mixup_type == "bg":
425
- neg_item = torch.randint(0, len(self), size=(1,)).item()
426
- neg_frame, _ = get_frames(neg_item)
427
- pos_frame, file_group = get_frames(item)
428
-
429
- b, c, h, w = neg_frame.shape
430
- neg_mask = torch.zeros(b, 1, h, w)
431
- pos_mask = torch.ones(b, 1, h, w)
432
-
433
- if torch.rand(1).item() > 0.5:
434
- bg_frame = neg_frame
435
- bg_mask = neg_mask
436
- fg_frame = F.interpolate(pos_frame, scale_factor=0.5, mode="bilinear")
437
- fg_mask = F.interpolate(pos_mask, scale_factor=0.5, mode="bilinear")
438
- else:
439
- bg_frame = pos_frame
440
- bg_mask = pos_mask
441
- fg_frame = F.interpolate(neg_frame, scale_factor=0.5, mode="bilinear")
442
- fg_mask = F.interpolate(neg_mask, scale_factor=0.5, mode="bilinear")
443
-
444
- start_h = torch.randint(0, h // 2, size=(1,))
445
- start_w = torch.randint(0, w // 2, size=(1,))
446
- bg_frame[:, :, start_h:start_h + fg_frame.shape[2], start_w:start_w + fg_frame.shape[3]] = fg_frame
447
- bg_mask[:, :, start_h:start_h + fg_frame.shape[2], start_w:start_w + fg_frame.shape[3]] = fg_mask
448
-
449
- data_dict["frames"] = bg_frame
450
- data_dict["image_masks"] = bg_mask
451
-
452
- elif mixup_type == "patch":
453
- neg_item = torch.randint(0, len(self), size=(1,)).item()
454
- neg_frame, _ = get_frames(neg_item)
455
- pos_frame, file_group = get_frames(item)
456
- frames, masks = create_mixed_image(pos_frame, neg_frame, self.patch_size)
457
- data_dict["frames"] = frames
458
- data_dict["image_masks"] = masks
459
-
460
- elif mixup_type is None:
461
- frames, file_group = get_frames(item)
462
-
463
- data_dict["frames"] = frames
464
- b, c, h, w = frames.shape
465
- data_dict["image_masks"] = torch.ones(b, 1, h, w)
466
- else:
467
- raise ValueError(f"Unknown mixup type {mixup_type}")
468
-
469
- if "original_length" in data_dict:
470
- if self._expected_num_frames() == 1:
471
- frame_nums = torch.tensor([0])
472
- else:
473
- frame_nums = torch.tensor([
474
- int(f.split("/")[-1].split("_")[-1].split(".")[0]) for f in file_group])
475
-
476
- data_dict["frame_nums"] = frame_nums
477
- frame_fracs = ((frame_nums + .5) / (self._expected_num_frames()))
478
- frame_position = (frame_fracs * data_dict["original_length"]) / data_dict["total_length"]
479
- data_dict["frame_position"] = frame_position
480
-
481
- if self.use_caption:
482
- if "word" in self.metadata:
483
- words = self.metadata["word"].iloc[item]
484
- start = self.metadata["start"].iloc[item]
485
- end = self.metadata["end"].iloc[item]
486
- if isinstance(words, float):
487
- words = [""]
488
- start = [0.0]
489
- end = [-1.0]
490
-
491
- data_dict["caption"] = {
492
- "words": words,
493
- "start": start,
494
- "end": end,
495
- }
496
- if "text" in self.metadata:
497
- data_dict["text"] = self.metadata["text"].iloc[item]
498
-
499
- if self.use_semseg:
500
- semseg_path = join(self._semseg_root(), self.metadata["semseg_file"].iloc[item])
501
- semseg = Image.open(semseg_path)
502
- if self.label_transform is not None:
503
- semseg = np.array(self.label_transform(semseg))
504
- data_dict["semseg"] = semseg
505
- data_dict["metadata"]["semseg_file"] = semseg_path
506
-
507
- # if hasattr(self, "num_classes"):
508
- # data_dict["num_pixels_per_class"] = F.one_hot(
509
- # torch.tensor(semseg).to(torch.int64), self.num_classes() + 1).sum(dim=[0, 1])
510
-
511
- return data_dict
512
-
513
- def __getitem__(self, item):
514
- try:
515
- data_dict = self._base_get_item(item)
516
- if self.use_hn:
517
- indices = torch.cat([self.metadata["hn0"].iloc[item], self.metadata["hn1"].iloc[item]], dim=0)
518
- neg_index = indices[torch.randint(0, indices.shape[0], (1,))]
519
- negative_dict = self._base_get_item(neg_index)
520
- data_dict["negatives"] = negative_dict
521
- return data_dict
522
- except (audioread.exceptions.NoBackendError, EOFError) as e:
523
- # raise e
524
- bad_path = self.metadata["audio_file"].iloc[item]
525
- print(e)
526
- print(f"Removing bad audio file {bad_path}")
527
- # os.remove(bad_path)
528
- return None
529
- except ValueError as e:
530
- # raise e
531
- bad_path = self.metadata["audio_file"].iloc[item]
532
- if "Input signal length=0" in str(e):
533
- print(e)
534
- print(f"Removing bad file {bad_path} due to input signal length=0")
535
- # os.remove(bad_path)
536
- return None
537
- except OSError as e:
538
- # raise e
539
- bad_paths = self.metadata["frame_files"].iloc[item]
540
- for bad_path in bad_paths:
541
- print(e)
542
- print(f"Removing bad frame file {bad_path}")
543
- return None
544
- except RuntimeError as e:
545
- # raise e
546
- bad_path = self.metadata["audio_file"].iloc[item]
547
- print(e)
548
- print(f"Removing bad audio file {bad_path}")
549
- # os.remove(bad_path)
550
- return None
551
-
552
-
553
- class PlacesAudio(AVDataset):
554
-
555
- def _load_info(self, split) -> pd.DataFrame:
556
- df = pd.read_json(join(os.path.dirname(self._audio_root()), "metadata", f"{split}.json"))
557
- df["id"] = df["data"].apply(lambda d: d["wav"][5:-4])
558
-
559
- if self.use_caption:
560
- if split == "train":
561
- word_df = pd.read_json(
562
- join(os.path.dirname(self._audio_root()), "metadata", f"word-alignment-{split}.json")
563
- )
564
- else:
565
- word_df = pd.read_csv(
566
- join(os.path.dirname(self._audio_root()), "metadata", f"word-alignment-{split}.csv")) \
567
- .groupby("id").aggregate(lambda g: list(g)).reset_index().drop("Unnamed: 0", axis=1)
568
- df = pd.merge(df, word_df, on="id", how="outer")
569
- return df
570
-
571
- def _missing_threshold(self) -> float:
572
- # return 0.0
573
- return 0.97 # TODO fix
574
-
575
- def _expected_num_frames(self):
576
- return 1
577
-
578
- def default_target_length(self) -> int:
579
- return 20
580
-
581
- def _frame_root(self) -> str:
582
- return join(os.path.dirname(self.root), "places_subset")
583
-
584
- def _audio_root(self) -> str:
585
- return join(self.root, "wavs")
586
-
587
- def _embed_root(self) -> str:
588
- return join(self.root, "embedding", self.audio_embed_model)
589
-
590
- def _dataset_folder(self) -> str:
591
- return "PlacesAudio_400k_distro"
592
-
593
- def _get_audio_file(self, row) -> str:
594
- return join(self._audio_root(), row["id"] + ".wav")
595
-
596
- def _get_frame_files(self, row) -> List[str]:
597
- return [join(self._frame_root(), row["data"]["image"])]
598
-
599
- def _get_embed_file(self, row) -> str:
600
- return join(self._embed_root(), row["id"] + ".npz")
601
-
602
-
603
- class AudioSet(AVDataset):
604
- def _expected_num_frames(self):
605
- return 10
606
-
607
- def default_target_length(self) -> int:
608
- return 20
609
-
610
- def _dataset_folder(self) -> str:
611
- return "audioset-raw"
612
-
613
- def _missing_threshold(self) -> float:
614
- if self.split == "val" or self.split == "test":
615
- return 0.02
616
- else:
617
- return 0.17
618
-
619
- def train_seg_file(self):
620
- return "unbalanced_train_segments.csv"
621
-
622
- def _load_info(self, split) -> pd.DataFrame:
623
- if split == "train":
624
- df = pd.read_csv(join(self.root, "metadata", self.train_seg_file()))
625
- elif split == "val" or split == "test":
626
- df = pd.read_csv(join(self.root, "metadata", "eval_segments_subset.csv"))
627
- else:
628
- raise ValueError(f"Unknown split {split}")
629
-
630
- labels = pd.read_csv(join(self.root, "metadata", "class_labels_indices.csv"))
631
- mid_to_index = dict(zip(labels["mid"], labels["index"]))
632
- df["tags"] = df["positive_labels"].apply(lambda l: [mid_to_index[e] for e in l.strip('"').split(",")])
633
-
634
- self.num_tags = max(*[i for k, i in mid_to_index.items()]) + 1
635
- df["id"] = df.apply(lambda r: f"{r.YTID}_{r.start_seconds}_{r.end_seconds}", axis=1)
636
- return df
637
-
638
- def _frame_root(self) -> str:
639
- return join(self.root, "frames")
640
-
641
- def _audio_root(self) -> str:
642
- return join(self.root, "audio")
643
-
644
- def _all_frame_files(self) -> Set[str]:
645
- frame_files = set()
646
-
647
- for entry in os.scandir(self._frame_root()):
648
- if entry.is_file():
649
- frame_files.add(entry.path)
650
- elif entry.is_dir():
651
- for subentry in os.scandir(entry.path):
652
- if subentry.is_file():
653
- frame_files.add(subentry.path)
654
-
655
- return frame_files
656
-
657
- def _all_audio_files(self) -> Set[str]:
658
- return set(entry.path for entry in os.scandir(self._audio_root()) if entry.is_file())
659
-
660
- def _all_embed_files(self) -> Set[str]:
661
- return set(entry.path for entry in os.scandir(self._embed_root()) if entry.is_file())
662
-
663
- def _embed_root(self) -> str:
664
- return join(self.root, "embedding", self.audio_embed_model)
665
-
666
- def prefix(self):
667
- return ""
668
-
669
- def _get_audio_file(self, row) -> str:
670
- return f"{self.root}/audio/{self.prefix()}{row.id}.mp3"
671
-
672
- def _get_frame_files(self, row) -> List[str]:
673
- return [f"{self.root}/frames/frame_{fn}/{self.prefix()}{row.id}.jpg" for fn in range(10)]
674
-
675
- def _get_embed_file(self, row) -> str:
676
- return f"{self.root}/embedding/{self.audio_embed_model}/{self.prefix()}{row.id}.npz"
677
-
678
-
679
- class AudioSetEval(AudioSet):
680
-
681
- def _dataset_folder(self) -> str:
682
- return "audioset-eval"
683
-
684
- def _get_frame_files(self, row) -> List[str]:
685
- base_path = f"{self.root}/frames/{self.prefix()}{row.id}_"
686
- return [base_path + f"{fn}.jpg" for fn in range(10)]
687
-
688
- def prefix(self):
689
- return ""
690
-
691
-
692
- class ADE20K(AVDataset):
693
-
694
- def _split_name(self, split):
695
- if split == "val":
696
- return "validation"
697
- elif split == "train":
698
- return "training"
699
- else:
700
- raise ValueError(f"Unknown split name {split}")
701
-
702
- def _load_info(self, split) -> pd.DataFrame:
703
- df = pd.read_json(join(self.root, "metadata_with_caption_dedup.json"))
704
- df["id"] = df["image"]
705
- df = df[df["image"].apply(lambda f: f.split("/")[0] == split)]
706
-
707
- if self.use_caption:
708
- df["word"] = df["caption"].apply(lambda c: c["words"])
709
- df["start"] = df["caption"].apply(lambda c: c["start"])
710
- df["end"] = df["caption"].apply(lambda c: c["end"])
711
- df["text"] = df["word"].apply(lambda l: " ".join(l))
712
- return df
713
-
714
- def _missing_threshold(self) -> float:
715
- return 0.03
716
-
717
- def _expected_num_frames(self):
718
- return 1
719
-
720
- def default_target_length(self) -> int:
721
- return 20
722
-
723
- def _dataset_folder(self) -> str:
724
- return "ADE20K"
725
-
726
- def _frame_root(self) -> str:
727
- return join(self.root, "frames")
728
-
729
- def _audio_root(self) -> str:
730
- return join(self.root, "audio")
731
-
732
- def _semseg_root(self) -> str:
733
- return join(self.root, "annotations")
734
-
735
- def _embed_root(self) -> str:
736
- return join(self.root, "embedding", self.audio_embed_model)
737
-
738
- def _get_audio_file(self, row) -> str:
739
- return join(self._audio_root(), row["audio"])
740
-
741
- def _get_frame_files(self, row) -> List[str]:
742
- return [join(self._frame_root(), row["image"])]
743
-
744
- def _get_semseg_file(self, row) -> str:
745
- return join(self._semseg_root(), row["seg"])
746
-
747
- def _get_embed_file(self, row) -> str:
748
- return join(self._embed_root(), row["image"].replace(".jpg", ".npz"))
749
-
750
- def num_classes(self):
751
- return 3662
752
-
753
-
754
- class ADE20KPromptedBase(AVDataset):
755
-
756
- def _expected_num_frames(self):
757
- return 1
758
-
759
- def default_target_length(self) -> int:
760
- return 20
761
-
762
- def _frame_root(self) -> str:
763
- return join(self.root, "frames")
764
-
765
- def _audio_root(self) -> str:
766
- return join(self.root, "audio")
767
-
768
- def _semseg_root(self) -> str:
769
- return join(self.root, "annotations")
770
-
771
- def _embed_root(self) -> str:
772
- return join(self.root, "embedding", self.audio_embed_model)
773
-
774
- def _get_frame_files(self, row) -> List[str]:
775
- return [join(self._frame_root(), row["image_location"])]
776
-
777
- def _get_semseg_file(self, row) -> str:
778
- return join(self._semseg_root(), row["image_location"].replace(".jpg", "_seg.png"))
779
-
780
- def _get_embed_file(self, row) -> str:
781
- return join(self._embed_root(), row["image_location"].replace(".jpg", ".npz"))
782
-
783
- def num_classes(self):
784
- return 3662
785
-
786
- def _missing_threshold(self) -> float:
787
- return 0.0
788
-
789
-
790
- class ADE20KSpeechPrompted(ADE20KPromptedBase):
791
-
792
- def _get_audio_file(self, row) -> str:
793
- return join(self._audio_root(), row["speech_prompt_file"].split("/")[-1])
794
-
795
- def _dataset_folder(self) -> str:
796
- return "ADE20KSpeechPrompted"
797
-
798
- def _audio_root(self) -> str:
799
- # return join(self.root, "audio-noise-10") # TODO Remove
800
- return join(self.root, "audio") # TODO Remove
801
-
802
- def _load_info(self, split) -> pd.DataFrame:
803
- df = pd.read_csv(join(self.root, "prompted_segmentation.csv"))
804
- df = df[df["speech_prompt_file"].apply(lambda s: isinstance(s, str))]
805
- df = df[df["ade_class_id"].apply(lambda id: id != 0)]
806
- df["id"] = df["image_location"]
807
- return df
808
-
809
-
810
- class ADE20KSoundPrompted(ADE20KPromptedBase):
811
-
812
- def _get_audio_file(self, row) -> str:
813
- return join(self._audio_root(), row["vggsound_file"].split("/")[-1])
814
-
815
- def _dataset_folder(self) -> str:
816
- return "ADE20KSoundPrompted"
817
-
818
- def _load_info(self, split) -> pd.DataFrame:
819
- df = pd.read_csv(join(self.root, "prompted_segmentation.csv"))
820
- df = df[df["vggsound_file"].apply(lambda s: isinstance(s, str))]
821
- df = df[df["ade_class_id"].apply(lambda id: id != 0)]
822
- df["id"] = df["image_location"]
823
- return df
824
-
825
-
826
- class PlacesAndAudioSet(Dataset):
827
-
828
- def __init__(self, **kwargs):
829
- self.ds1 = PlacesAudio(**kwargs, n_frames=1)
830
- self.ds2 = AudioSet(**kwargs, n_frames=1)
831
-
832
- def __len__(self):
833
- return len(self.ds1)
834
-
835
- def __getitem__(self, item):
836
- if torch.rand(1).item() > .5:
837
- d = self.ds2[torch.randint(0, len(self.ds2) - 1, size=(1,)).item()]
838
- if d is not None:
839
- d["source"] = 1
840
- else:
841
- d = self.ds1[item]
842
- if d is not None:
843
- d["source"] = 0
844
- return d
845
-
846
-
847
- class AVDataModule(pl.LightningDataModule):
848
- def __init__(self,
849
- dataset_name,
850
- load_size,
851
- image_aug,
852
- audio_aug,
853
- extra_audio_masking,
854
- audio_model_type,
855
- pytorch_data_dir,
856
- use_cached_embs,
857
- batch_size,
858
- num_workers,
859
- audio_level,
860
- neg_audio,
861
- data_for_plotting,
862
- use_original_val_set,
863
- use_extra_val_sets,
864
- quad_mixup,
865
- bg_mixup,
866
- patch_mixup,
867
- patch_size,
868
- **kwargs):
869
-
870
- super().__init__()
871
- self.dataset_name = dataset_name
872
- self.load_size = load_size
873
- self.image_aug = image_aug
874
- self.audio_aug = audio_aug
875
- self.extra_audio_masking = extra_audio_masking
876
- self.audio_model_type = audio_model_type
877
- self.pytorch_data_dir = pytorch_data_dir
878
- self.use_cached_embs = use_cached_embs
879
- self.batch_size = batch_size
880
- self.num_workers = num_workers
881
- self.data_for_plotting = data_for_plotting
882
- self.audio_level = audio_level
883
- self.neg_audio = neg_audio
884
-
885
- self.quad_mixup = quad_mixup
886
- self.bg_mixup = bg_mixup
887
- self.patch_mixup = patch_mixup
888
- self.patch_size = patch_size
889
-
890
- self.loader_args = dict(
891
- num_workers=self.num_workers,
892
- batch_size=self.batch_size,
893
- )
894
- self.save_hyperparameters()
895
- self.extra_args = kwargs
896
-
897
- self.use_original_val_set = use_original_val_set
898
- self.use_extra_val_sets = use_extra_val_sets
899
-
900
- def maybe_unpack(self, remove_source):
901
- targets = [
902
- (
903
- join(self.pytorch_data_dir, "audioset-subset", "frame_archives"),
904
- join(self.pytorch_data_dir, "audioset-subset", "frames"),
905
- 1
906
- ),
907
- (
908
- join(self.pytorch_data_dir, "audioset-raw", "frame_archives"),
909
- join(self.pytorch_data_dir, "audioset-raw", "frames"),
910
- 4
911
- ),
912
- (
913
- join(self.pytorch_data_dir, "audioset-raw", "audio_archives"),
914
- join(self.pytorch_data_dir, "audioset-raw", "audio"),
915
- 1
916
- ),
917
-
918
- ]
919
-
920
- for (archive_dir, target_dir, n_parts) in targets:
921
- if not os.path.exists(target_dir) and os.path.exists(archive_dir):
922
- print(f"Could not find {target_dir}, attempting to unpack archives")
923
- if os.path.exists(archive_dir):
924
- untar_all(archive_dir, target_dir, remove_source)
925
- else:
926
- raise RuntimeError(f"Could not find archive folder: {archive_dir}")
927
-
928
- def get_dataset_by_name(self, name, stage, data_for_plotting, n_frames=None):
929
-
930
- if name == "vggss":
931
- resize_op = T.Resize((self.load_size, self.load_size), Image.BILINEAR)
932
- else:
933
- resize_op = T.Resize(self.load_size, Image.BILINEAR)
934
-
935
- img_transform = T.Compose([
936
- resize_op,
937
- T.CenterCrop(self.load_size),
938
- T.ToTensor(),
939
- norm])
940
-
941
- if self.image_aug:
942
- train_img_transform = T.Compose([
943
- T.RandomResizedCrop(self.load_size),
944
- T.RandomHorizontalFlip(),
945
- T.ColorJitter(.2, .2, .2, .2),
946
- T.RandomGrayscale(),
947
- T.ToTensor(),
948
- norm])
949
- val_img_transform = img_transform
950
- else:
951
- train_img_transform = img_transform
952
- val_img_transform = img_transform
953
-
954
- if self.audio_aug:
955
- train_audio_aug = True
956
- val_audio_aug = False
957
- else:
958
- train_audio_aug = False
959
- val_audio_aug = False
960
-
961
- if self.audio_model_type == "hubert":
962
- from featurizers.Hubert import HubertAudioTransform
963
- audio_transform = HubertAudioTransform()
964
- else:
965
- audio_transform = None
966
-
967
- if self.audio_model_type == "passt":
968
- sample_rate = 32000
969
- else:
970
- sample_rate = 16000
971
-
972
- if not self.use_cached_embs:
973
- if self.audio_model_type == "hubert":
974
- self.extra_args["use_audio"] = True
975
- elif self.audio_model_type in {"audiomae", "audiomae-finetuned", "cavmae", "cavmae-mixed", "imagebind"}:
976
- self.extra_args["use_spec"] = True
977
- elif self.audio_model_type == "davenet":
978
- self.extra_args["use_audio"] = True
979
- self.extra_args["use_davenet_spec"] = True
980
- elif self.audio_model_type == "fnac":
981
- self.extra_args["use_audio"] = True
982
- self.extra_args["use_fnac_spec"] = True
983
- else:
984
- raise ValueError(f"Unknown audio model type {self.audio_model_type}")
985
-
986
- if self.audio_model_type == "cavmae" or self.audio_model_type == "cavmae-mixed":
987
- self.extra_args["spec_mean"] = -5.081
988
- self.extra_args["spec_std"] = 4.4849
989
- elif self.audio_model_type == "imagebind":
990
- self.extra_args["spec_mean"] = -4.268
991
- self.extra_args["spec_std"] = 9.138
992
-
993
- # if self.audio_model_type in {"audiomae", "audiomae-finetune", "cavmae"} \
994
- # and "override_target_length" not in self.extra_args:
995
- if "override_target_length" not in self.extra_args:
996
- self.extra_args["override_target_length"] = 10
997
-
998
- data_args = dict(
999
- root=self.pytorch_data_dir,
1000
- use_frames=True,
1001
- audio_transform=audio_transform,
1002
- sample_rate=sample_rate,
1003
- audio_level=self.audio_level,
1004
- **self.extra_args
1005
- )
1006
-
1007
- if n_frames is not None:
1008
- data_args["n_frames"] = n_frames
1009
-
1010
- train_args = dict(
1011
- frame_transform=train_img_transform,
1012
- extra_audio_masking=self.extra_audio_masking,
1013
- neg_audio=self.neg_audio,
1014
- quad_mixup=self.quad_mixup,
1015
- bg_mixup=self.bg_mixup,
1016
- patch_mixup=self.patch_mixup,
1017
- patch_size=self.patch_size,
1018
- audio_aug=train_audio_aug
1019
- )
1020
- val_args = dict(
1021
- frame_transform=val_img_transform,
1022
- audio_aug=val_audio_aug
1023
- )
1024
-
1025
- if data_for_plotting:
1026
- val_args["use_audio"] = True
1027
- val_args["use_spec"] = True
1028
-
1029
- if "ade" in name:
1030
- label_transform = T.Compose([
1031
- T.Resize(self.load_size, Image.NEAREST),
1032
- T.CenterCrop(self.load_size),
1033
- prep_ade_label
1034
- ])
1035
- else:
1036
- label_transform = T.Compose([
1037
- T.Resize(self.load_size, Image.NEAREST),
1038
- T.CenterCrop(self.load_size)
1039
- ])
1040
-
1041
- val_args["use_audio"] = True
1042
- val_args["label_transform"] = label_transform
1043
-
1044
- if name == "places-audio":
1045
- dataset_constructor = PlacesAudio
1046
- elif name == "mixed-full":
1047
- dataset_constructor = PlacesAndAudioSet
1048
- elif name == "audio-set-full":
1049
- dataset_constructor = AudioSet
1050
- elif name == "audio-set-eval":
1051
- dataset_constructor = AudioSetEval
1052
- elif name == "ade":
1053
- val_args["use_semseg"] = True
1054
- dataset_constructor = ADE20K
1055
- elif name == "ade-speech-prompted":
1056
- val_args["use_semseg"] = True
1057
- dataset_constructor = ADE20KSpeechPrompted
1058
- elif name == "ade-sound-prompted":
1059
- val_args["use_semseg"] = True
1060
- dataset_constructor = ADE20KSoundPrompted
1061
- else:
1062
- raise ValueError(f"Unknown dataset name {name}")
1063
-
1064
- data_args["use_audio_embed"] = self.use_cached_embs
1065
- data_args["audio_embed_model"] = self.audio_model_type
1066
-
1067
- if stage == "full":
1068
- val_dataset = dataset_constructor(split="val", **{**data_args, **val_args})
1069
- train_dataset = dataset_constructor(split="train", **{**data_args, **val_args})
1070
- return ConcatDataset([train_dataset, val_dataset])
1071
- elif stage == "fit":
1072
- return dataset_constructor(split="train", **{**data_args, **train_args})
1073
- elif stage == "validate":
1074
- return dataset_constructor(split="val", **{**data_args, **val_args})
1075
- else:
1076
- raise ValueError(f"Unknown stage: {stage}")
1077
-
1078
- def _maybe_subset(self, dataset, length):
1079
- if len(dataset) > length and self.dataset_name not in {"ade-sound-prompted", "ade-speech-prompted", "vggss"}:
1080
- print("Using a subset of validation data")
1081
- return Subset(dataset, generate_subset(len(dataset), length))
1082
- else:
1083
- print("Not using val subset")
1084
- return dataset
1085
-
1086
- def _make_val_datasets(self):
1087
- val_sets = []
1088
- if self.use_original_val_set:
1089
- val_sets.append(self._maybe_subset(self.get_dataset_by_name(
1090
- self.dataset_name, "validate", self.data_for_plotting), 1000))
1091
-
1092
- if self.use_extra_val_sets:
1093
- val_sets.append(self._maybe_subset(self.get_dataset_by_name(
1094
- "places-audio", "validate", self.data_for_plotting), 1000))
1095
- val_sets.append(self._maybe_subset(self.get_dataset_by_name(
1096
- "audio-set-eval", "validate", False, n_frames=1), 1000))
1097
- val_sets.append(self.get_dataset_by_name(
1098
- "ade-speech-prompted", "validate", True))
1099
- val_sets.append(self.get_dataset_by_name(
1100
- "ade-sound-prompted", "validate", self.data_for_plotting))
1101
-
1102
- return val_sets
1103
-
1104
- def setup(self, stage: str):
1105
- if stage == "full":
1106
- self.full_dataset = self.get_dataset_by_name(self.dataset_name, stage, self.data_for_plotting)
1107
- elif stage == "fit":
1108
- self.train_dataset = self.get_dataset_by_name(self.dataset_name, stage, self.data_for_plotting)
1109
- self.val_datasets = self._make_val_datasets()
1110
- elif stage == "validate":
1111
- self.val_datasets = self._make_val_datasets()
1112
- else:
1113
- raise ValueError(f"Unknown stage: {stage}")
1114
-
1115
- def train_dataloader(self):
1116
- return DataLoader(self.train_dataset, shuffle=True, **self.loader_args, collate_fn=custom_coallate)
1117
-
1118
- def subsampled_train_dataloader(self, k=5000):
1119
- if len(self.train_dataset) > k:
1120
- ds = Subset(self.train_dataset, generate_subset(len(self.train_dataset), k))
1121
- else:
1122
- ds = self.train_dataset
1123
-
1124
- return DataLoader(ds, shuffle=True, **self.loader_args, collate_fn=custom_coallate)
1125
-
1126
- def val_dataloader(self):
1127
- return [
1128
- DataLoader(dataset, shuffle=False, **self.loader_args, collate_fn=custom_coallate)
1129
- for dataset in self.val_datasets
1130
- ]
1131
-
1132
- def full_dataloader(self):
1133
- return DataLoader(self.full_dataset, shuffle=False, **self.loader_args, collate_fn=custom_coallate)
1134
-
1135
-
1136
- def generate_subset(n, batch, seed=0):
1137
- np.random.seed(seed)
1138
- return np.random.permutation(n)[:batch]
1139
-
1140
-
1141
- def prep_ade_label(img):
1142
- seg = np.array(img)
1143
- class_labels = (seg[:, :, 0] / 10).astype(np.int32) * 256 + (seg[:, :, 1].astype(np.int32))
1144
- return class_labels
1145
-
1146
-
1147
- def maybe_replace(e, not_none):
1148
- if e is not None:
1149
- return e
1150
- else:
1151
- print("Warning found a None in the dataset indicitive of a loading failure, replacing it with another item")
1152
- return not_none[0]
1153
-
1154
-
1155
- empty_caption = {
1156
- "words": [],
1157
- "start": [],
1158
- "end": [],
1159
- }
1160
-
1161
-
1162
- def custom_coallate(l):
1163
- if l is None:
1164
- return l
1165
-
1166
- not_none = [e for e in l if e is not None]
1167
- assert len(not_none) > 0
1168
-
1169
- l = [maybe_replace(e, not_none) for e in l]
1170
-
1171
- to_merge = {}
1172
-
1173
- def pop_or_default(dict, k, default):
1174
- if k in dict:
1175
- return dict.pop(k)
1176
- else:
1177
- print(f"WARNING: Could not find {k}, using {default}")
1178
- return default
1179
-
1180
- if "caption" in l[0]:
1181
- to_merge["caption"] = [pop_or_default(l[i], "caption", empty_caption) for i in range(len(l))]
1182
-
1183
- if "text" in l[0]:
1184
- to_merge["text"] = [pop_or_default(l[i], "text", "") for i in range(len(l))]
1185
-
1186
- result = default_collate(l)
1187
-
1188
- return {**result, **to_merge}
1189
-
1190
-
1191
- if __name__ == "__main__":
1192
-
1193
- from featurizers.Hubert import HubertAudioTransform
1194
-
1195
- pytorch_data_dir = "/pytorch-data"
1196
- dataset_constructor = PlacesAudio
1197
- split = "val"
1198
-
1199
- img_transform = T.Compose([
1200
- T.Resize(224, Image.BILINEAR),
1201
- T.CenterCrop(224),
1202
- T.ToTensor(),
1203
- norm])
1204
-
1205
- video_transform = T.Compose([
1206
- T.Resize(224, Image.BILINEAR),
1207
- T.CenterCrop(224),
1208
- norm])
1209
-
1210
- label_transform = T.Compose([
1211
- T.Resize(224, Image.NEAREST),
1212
- T.CenterCrop(224)
1213
- ])
1214
-
1215
- audio_transform = HubertAudioTransform()
1216
-
1217
- data_args = dict(
1218
- root=pytorch_data_dir,
1219
- frame_transform=img_transform,
1220
- use_frames=True,
1221
- use_spec=True,
1222
- use_audio=True,
1223
- use_caption=False,
1224
- use_semseg=False,
1225
- label_transform=label_transform,
1226
- audio_transform=audio_transform,
1227
- use_audio_embed=False,
1228
- audio_embed_model="audiomae",
1229
- extra_audio_masking=False,
1230
- neg_audio=False,
1231
- override_target_length=10,
1232
- audio_level=False,
1233
- quad_mixup=.3,
1234
- patch_mixup=.3,
1235
- bg_mixup=.3,
1236
- )
1237
-
1238
-
1239
- def return_datasets(dataset_constructor, split):
1240
- dataset = dataset_constructor(split=split, **data_args)
1241
- return dataset
1242
-
1243
-
1244
- train_ds = return_datasets(dataset_constructor, split)
1245
-
1246
- print(len(train_ds))
1247
- train_loader = DataLoader(train_ds, batch_size=1, shuffle=False, num_workers=36, collate_fn=custom_coallate)
1248
- for batch in tqdm(train_loader):
1249
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/data/__init__.py DELETED
File without changes
DenseAV/denseav/data/make_tarballs.py DELETED
@@ -1,108 +0,0 @@
1
- import glob
2
- import os
3
- import tarfile
4
- from glob import glob
5
- from io import BytesIO
6
- from os.path import join
7
-
8
- from torch.utils.data import Dataset, DataLoader
9
- from tqdm import tqdm
10
- from pathlib import Path
11
-
12
- from denseav.shared import batch
13
-
14
- import tempfile
15
- import shutil
16
-
17
-
18
- class Tarballer(Dataset):
19
-
20
- def __init__(self, source, target, n):
21
- source_path = Path(source)
22
- self.frames = [f.relative_to(source_path) for f in source_path.rglob('*') if f.is_file()]
23
- assert (len(self.frames) > 0)
24
- self.source = source
25
- self.target_dir = target
26
- self.batched = list(batch(self.frames, n))
27
- os.makedirs(self.target_dir, exist_ok=True)
28
-
29
- def __len__(self):
30
- return len(self.batched)
31
-
32
- def __getitem__(self, item):
33
- with tarfile.open(join(self.target_dir, f"{item}.tar"), "w") as tar:
34
- for relpath in self.batched[item]:
35
- abs_path = os.path.join(self.source, str(relpath)) # Convert to string here
36
- with open(abs_path, "rb") as file:
37
- file_content = file.read()
38
- info = tarfile.TarInfo(name=str(relpath)) # Convert to string here
39
- info.size = len(file_content)
40
- tar.addfile(info, fileobj=BytesIO(file_content))
41
-
42
- return 0
43
-
44
-
45
- class UnTarballer:
46
-
47
- def __init__(self, archive_dir, target_dir, remove_source=False):
48
- self.tarballs = sorted(glob(join(archive_dir, "*.tar")))
49
- self.target_dir = target_dir
50
- self.remove_source = remove_source # New flag to determine if source tarball should be removed
51
- os.makedirs(self.target_dir, exist_ok=True)
52
-
53
- def __len__(self):
54
- return len(self.tarballs)
55
-
56
- def __getitem__(self, item):
57
- with tarfile.open(self.tarballs[item], "r") as tar:
58
- # Create a unique temporary directory inside the target directory
59
- with tempfile.TemporaryDirectory(dir=self.target_dir) as tmpdirname:
60
- tar.extractall(tmpdirname) # Extract to the temporary directory
61
-
62
- # Move contents from temporary directory to final target directory
63
- for src_dir, dirs, files in os.walk(tmpdirname):
64
- dst_dir = src_dir.replace(tmpdirname, self.target_dir, 1)
65
- os.makedirs(dst_dir, exist_ok=True)
66
- for file_ in files:
67
- src_file = os.path.join(src_dir, file_)
68
- dst_file = os.path.join(dst_dir, file_)
69
- shutil.move(src_file, dst_file)
70
-
71
- # Remove the source tarball if the flag is set to True
72
- if self.remove_source:
73
- os.remove(self.tarballs[item])
74
-
75
- return 0
76
-
77
- def untar_all(archive_dir, target_dir, remove_source):
78
- loader = DataLoader(UnTarballer(archive_dir, target_dir, remove_source), num_workers=24)
79
- for _ in tqdm(loader):
80
- pass
81
-
82
-
83
- if __name__ == "__main__":
84
- # loader = DataLoader(Tarballer(
85
- # join("/pytorch-data", "audioset-raw", "audio"),
86
- # join("/pytorch-data", "audioset-raw", "audio_archives")
87
- # ), num_workers=24)
88
-
89
- # loader = DataLoader(Tarballer(
90
- # join("/pytorch-data", "audioset-raw", "frames"),
91
- # join("/pytorch-data", "audioset-raw", "frame_archives"),
92
- # 5000
93
- # ), num_workers=24)
94
-
95
- # loader = DataLoader(Tarballer(
96
- # join("/pytorch-data", "ADE20KLabels"),
97
- # join("/pytorch-data", "ADE20KLabelsAr"),
98
- # 100
99
- # ), num_workers=24)
100
- #
101
- # for _ in tqdm(loader):
102
- # pass
103
- #
104
- # #
105
- #
106
- untar_all(
107
- join("/pytorch-data", "audioset-raw", "frame_archives"),
108
- join("/pytorch-data", "audioset-raw", "frames_4"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/eval_utils.py DELETED
@@ -1,135 +0,0 @@
1
- import json
2
- from collections import defaultdict
3
-
4
- import matplotlib.pyplot as plt
5
- import numpy as np
6
- import torch
7
- import torch.nn.functional as F
8
- from torchmetrics.functional.classification import binary_average_precision
9
- from tqdm import tqdm
10
-
11
- from constants import *
12
- from denseav.shared import unnorm, remove_axes
13
-
14
-
15
- def prep_heatmap(sims, masks, h, w):
16
- masks = masks.to(torch.float32)
17
- hm = torch.einsum("bhwt,bt->bhw", sims, masks) / masks.sum(-1).reshape(-1, 1, 1)
18
- hm -= hm.min()
19
- hm /= hm.max()
20
- return F.interpolate(hm.unsqueeze(1), (h, w), mode="bilinear").squeeze(1)
21
-
22
-
23
- def iou(prediction, target):
24
- prediction = prediction > 0.0
25
- target = target > 0.5
26
- intersection = torch.logical_and(prediction, target).sum().float()
27
- union = torch.logical_or(prediction, target).sum().float()
28
- if union == 0:
29
- return 1.0
30
- return (intersection / union).item() # Convert to Python scalar
31
-
32
-
33
- def multi_iou(prediction, target, k=20):
34
- prediction = torch.tensor(prediction)
35
- target = torch.tensor(target)
36
- target = target > 0.5
37
-
38
- thresholds = torch.linspace(prediction.min(), prediction.max(), k)
39
- hard_pred = prediction.unsqueeze(0) > thresholds.reshape(k, 1, 1, 1, 1)
40
- target = torch.broadcast_to(target.unsqueeze(0), hard_pred.shape)
41
-
42
- # Calculate IoU for each threshold
43
- intersection = torch.logical_and(hard_pred, target).sum(dim=(1, 2, 3, 4)).float()
44
- union = torch.logical_or(hard_pred, target).sum(dim=(1, 2, 3, 4)).float()
45
- union = torch.where(union == 0, torch.tensor(1.0), union) # Avoid division by zero
46
- iou_scores = intersection / union
47
-
48
- # Find the best IoU and corresponding threshold
49
- best_iou, best_idx = torch.max(iou_scores, dim=0)
50
- # best_threshold = thresholds[best_idx]
51
- # print(best_threshold)
52
- return best_iou # , best_threshold.item()
53
-
54
-
55
- def get_paired_heatmaps(
56
- model,
57
- results,
58
- class_ids,
59
- timing,
60
- class_names=None):
61
- sims = model.sim_agg.get_pairwise_sims(
62
- results,
63
- raw=False,
64
- agg_sim=False,
65
- agg_heads=True
66
- ).squeeze(1).mean(-2)
67
-
68
- prompt_classes = torch.tensor(list(class_ids))
69
- gt = results["semseg"] == prompt_classes.reshape(-1, 1, 1)
70
- basic_masks = results[AUDIO_MASK] # BxT
71
- _, fullh, fullw = gt.shape
72
- basic_heatmaps = prep_heatmap(sims, basic_masks, fullh, fullw)
73
-
74
- if timing is not None:
75
- prompt_timing = np.array(list(timing))
76
- raw_timing = torch.tensor([json.loads(t) for t in prompt_timing])
77
- timing = torch.clone(raw_timing)
78
- timing[:, 0] -= .2
79
- timing[:, 1] += .2
80
- total_length = (results['total_length'] / 16000)[0]
81
- fracs = timing / total_length
82
- bounds = basic_masks.shape[1] * fracs
83
- bounds[:, 0] = bounds[:, 0].floor()
84
- bounds[:, 1] = bounds[:, 1].ceil()
85
- bounds = bounds.to(torch.int64)
86
- advanced_masks = (F.one_hot(bounds, basic_masks.shape[1]).cumsum(-1).sum(-2) == 1).to(basic_masks)
87
- advanced_heatmaps = prep_heatmap(sims, advanced_masks, fullh, fullw)
88
-
89
- metrics = defaultdict(list)
90
- unique_classes = torch.unique(prompt_classes)
91
-
92
- should_plot = class_names is not None
93
-
94
- if should_plot:
95
- prompt_names = np.array(list(class_names))
96
-
97
- for prompt_class in tqdm(unique_classes):
98
- subset = torch.where(prompt_classes == prompt_class)[0]
99
- gt_subset = gt[subset]
100
- basic_subset = basic_heatmaps[subset]
101
- metrics["basic_ap"].append(binary_average_precision(basic_subset.flatten(), gt_subset.flatten()))
102
- metrics["basic_iou"].append(multi_iou(basic_subset.flatten(), gt_subset.flatten()))
103
-
104
- if timing is not None:
105
- advanced_subset = advanced_heatmaps[subset]
106
- metrics["advanced_ap"].append(binary_average_precision(advanced_subset.flatten(), gt_subset.flatten()))
107
- metrics["advanced_iou"].append(multi_iou(advanced_subset.flatten(), gt_subset.flatten()))
108
-
109
- if should_plot:
110
- prompt_class_subset = prompt_classes[subset]
111
- name_subset = prompt_names[subset]
112
- print(prompt_class, name_subset, prompt_class_subset)
113
- n_imgs = min(len(subset), 5)
114
- if n_imgs > 1:
115
- fig, axes = plt.subplots(n_imgs, 5, figsize=(4 * 5, n_imgs * 3))
116
- frame_subset = unnorm(results[IMAGE_INPUT][subset].squeeze(1)).permute(0, 2, 3, 1)
117
- semseg_subset = results["semseg"][subset]
118
- for img_num in range(n_imgs):
119
- axes[img_num, 0].imshow(frame_subset[img_num])
120
- axes[img_num, 1].imshow(basic_subset[img_num])
121
- axes[img_num, 2].imshow(advanced_subset[img_num])
122
- axes[img_num, 3].imshow(gt_subset[img_num])
123
- axes[img_num, 4].imshow(semseg_subset[img_num], cmap="tab20", interpolation='none')
124
-
125
- axes[0, 0].set_title("Image")
126
- class_name = name_subset[0].split(",")[0]
127
- axes[0, 1].set_title(f"{class_name} Basic Heatmap")
128
- axes[0, 2].set_title(f"{class_name} Advanced Heatmap")
129
- axes[0, 3].set_title("True Mask")
130
- axes[0, 4].set_title("True Seg")
131
- remove_axes(axes)
132
- plt.tight_layout()
133
- plt.show()
134
-
135
- return metrics, unique_classes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/evaluate.py DELETED
@@ -1,87 +0,0 @@
1
- from os.path import join
2
- import hydra
3
- from omegaconf import DictConfig, OmegaConf
4
- from pytorch_lightning import Trainer
5
- from pytorch_lightning import seed_everything
6
- from pytorch_lightning.loggers import TensorBoardLogger
7
- from denseav.data.AVDatasets import AVDataModule
8
- from denseav.shared import load_trained_model
9
-
10
-
11
- @hydra.main(config_path="configs", config_name="av_align.yaml")
12
- def my_app(cfg: DictConfig) -> None:
13
- from saved_models import saved_model_dict
14
-
15
- seed_everything(0)
16
- print(OmegaConf.to_yaml(cfg))
17
-
18
- models_to_eval = [
19
- "denseav_language",
20
- "denseav_sound",
21
- ]
22
-
23
- checkpoint_dir = "../checkpoints"
24
- saved_models = saved_model_dict(checkpoint_dir)
25
- for model_name in models_to_eval:
26
- model_info = saved_models[model_name]
27
- extra_data_args = model_info["data_args"] if "data_args" in model_info else {}
28
- model_info["extra_args"]["output_root"] = "../"
29
- model_info["extra_args"]["neg_audio"] = False
30
- model_info["extra_args"]["image_mixup"] = 0.0
31
-
32
- model = load_trained_model(join(checkpoint_dir, model_info["chkpt_name"]), model_info["extra_args"])
33
- model.set_full_train(True)
34
-
35
- if model.image_model_type == "dinov2":
36
- load_size = cfg.load_size * 2
37
- else:
38
- load_size = cfg.load_size
39
-
40
- if model.image_model_type == "davenet":
41
- batch_size = cfg.batch_size // 2
42
- elif model.image_model_type == "imagebind":
43
- batch_size = cfg.batch_size
44
- else:
45
- batch_size = cfg.batch_size
46
-
47
- print(load_size)
48
-
49
- data_args = dict(
50
- dataset_name=cfg.dataset_name,
51
- load_size=load_size,
52
- image_aug=cfg.image_aug,
53
- audio_aug=cfg.audio_aug,
54
- audio_model_type=model.audio_model_type,
55
- pytorch_data_dir=cfg.pytorch_data_dir,
56
- use_cached_embs=model.use_cached_embs,
57
- batch_size=batch_size,
58
- num_workers=cfg.num_workers,
59
- extra_audio_masking=False,
60
- use_original_val_set=False,
61
- use_extra_val_sets=True,
62
- use_caption=True,
63
- data_for_plotting=False,
64
- n_frames=None,
65
- audio_level=False,
66
- neg_audio=False,
67
- quad_mixup=0.0,
68
- bg_mixup=0.0,
69
- patch_mixup=0.0,
70
- patch_size=8,
71
- )
72
- data_args = {**data_args, **extra_data_args}
73
-
74
- datamodule = AVDataModule(**data_args)
75
- log_dir = join(cfg.output_root, "logs", "evaluate", model_name)
76
- print(log_dir)
77
- tb_logger = TensorBoardLogger(log_dir, default_hp_metric=False)
78
- trainer = Trainer(
79
- accelerator='gpu',
80
- strategy="ddp",
81
- devices=cfg.num_gpus,
82
- logger=tb_logger)
83
- trainer.validate(model, datamodule)
84
-
85
-
86
- if __name__ == "__main__":
87
- my_app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/AudioMAE.py DELETED
@@ -1,570 +0,0 @@
1
- import math
2
- import os
3
- import warnings
4
- from functools import partial
5
-
6
- import numpy as np
7
- import torch
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- import torchaudio
11
- from timm.models.layers import to_2tuple
12
- from torch.utils.data import Dataset
13
- from torchaudio.functional import resample
14
- import pickle
15
-
16
-
17
- def _no_grad_trunc_normal_(tensor, mean, std, a, b):
18
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
19
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
20
- def norm_cdf(x):
21
- # Computes standard normal cumulative distribution function
22
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
23
-
24
- if (mean < a - 2 * std) or (mean > b + 2 * std):
25
- warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
26
- "The distribution of values may be incorrect.",
27
- stacklevel=2)
28
-
29
- with torch.no_grad():
30
- # Values are generated by using a truncated uniform distribution and
31
- # then using the inverse CDF for the normal distribution.
32
- # Get upper and lower cdf values
33
- l = norm_cdf((a - mean) / std)
34
- u = norm_cdf((b - mean) / std)
35
-
36
- # Uniformly fill tensor with values from [l, u], then translate to
37
- # [2l-1, 2u-1].
38
- tensor.uniform_(2 * l - 1, 2 * u - 1)
39
-
40
- # Use inverse cdf transform for normal distribution to get truncated
41
- # standard normal
42
- tensor.erfinv_()
43
-
44
- # Transform to proper mean, std
45
- tensor.mul_(std * math.sqrt(2.))
46
- tensor.add_(mean)
47
-
48
- # Clamp to ensure it's in the proper range
49
- tensor.clamp_(min=a, max=b)
50
- return tensor
51
-
52
-
53
- def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
54
- # type: (Tensor, float, float, float, float) -> Tensor
55
- r"""Fills the input Tensor with values drawn from a truncated
56
- normal distribution. The values are effectively drawn from the
57
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
58
- with values outside :math:`[a, b]` redrawn until they are within
59
- the bounds. The method used for generating the random values works
60
- best when :math:`a \leq \text{mean} \leq b`.
61
- Args:
62
- tensor: an n-dimensional `torch.Tensor`
63
- mean: the mean of the normal distribution
64
- std: the standard deviation of the normal distribution
65
- a: the minimum cutoff value
66
- b: the maximum cutoff value
67
- Examples:
68
- >>> w = torch.empty(3, 5)
69
- >>> nn.init.trunc_normal_(w)
70
- """
71
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
72
-
73
-
74
- class Mlp(nn.Module):
75
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
76
- super().__init__()
77
- out_features = out_features or in_features
78
- hidden_features = hidden_features or in_features
79
- self.fc1 = nn.Linear(in_features, hidden_features)
80
- self.act = act_layer()
81
- self.fc2 = nn.Linear(hidden_features, out_features)
82
- self.drop = nn.Dropout(drop)
83
-
84
- def forward(self, x):
85
- x = self.fc1(x)
86
- x = self.act(x)
87
- x = self.drop(x)
88
- x = self.fc2(x)
89
- x = self.drop(x)
90
- return x
91
-
92
-
93
- class Attention(nn.Module):
94
- def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
95
- super().__init__()
96
- self.num_heads = num_heads
97
- head_dim = dim // num_heads
98
- # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
99
- self.scale = qk_scale or head_dim ** -0.5
100
-
101
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
102
- self.attn_drop = nn.Dropout(attn_drop)
103
- self.proj = nn.Linear(dim, dim)
104
- self.proj_drop = nn.Dropout(proj_drop)
105
-
106
- def forward(self, x):
107
- B, N, C = x.shape
108
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
109
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
110
-
111
- attn = (q @ k.transpose(-2, -1)) * self.scale
112
- attn = attn.softmax(dim=-1)
113
- attn = self.attn_drop(attn)
114
-
115
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
116
- x = self.proj(x)
117
- x = self.proj_drop(x)
118
- return x
119
-
120
-
121
- def drop_path(x, drop_prob: float = 0., training: bool = False):
122
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
123
-
124
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
125
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
126
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
127
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
128
- 'survival rate' as the argument.
129
-
130
- """
131
- if drop_prob == 0. or not training:
132
- return x
133
- keep_prob = 1 - drop_prob
134
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
135
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
136
- random_tensor.floor_() # binarize
137
- output = x.div(keep_prob) * random_tensor
138
- return output
139
-
140
-
141
- class DropPath(nn.Module):
142
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
143
- """
144
-
145
- def __init__(self, drop_prob=None):
146
- super(DropPath, self).__init__()
147
- self.drop_prob = drop_prob
148
-
149
- def forward(self, x):
150
- return drop_path(x, self.drop_prob, self.training)
151
-
152
-
153
- class Block(nn.Module):
154
-
155
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
156
- drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
157
- super().__init__()
158
- self.norm1 = norm_layer(dim)
159
- self.attn = Attention(
160
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
161
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
- self.norm2 = norm_layer(dim)
164
- mlp_hidden_dim = int(dim * mlp_ratio)
165
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
166
-
167
- def forward(self, x):
168
- x = x + self.drop_path(self.attn(self.norm1(x)))
169
- x = x + self.drop_path(self.mlp(self.norm2(x)))
170
- return x
171
-
172
-
173
- class PatchEmbed(nn.Module):
174
- """ Image to Patch Embedding
175
- """
176
-
177
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
178
- super().__init__()
179
- img_size = to_2tuple(img_size)
180
- patch_size = to_2tuple(patch_size)
181
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
182
- self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
183
- self.img_size = img_size
184
- self.patch_size = patch_size
185
- self.num_patches = num_patches
186
-
187
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
188
-
189
- def forward(self, x):
190
- B, C, H, W = x.shape
191
- # FIXME look at relaxing size constraints
192
- # assert H == self.img_size[0] and W == self.img_size[1], \
193
- # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
194
- x = self.proj(x).flatten(2).transpose(1, 2)
195
- return x
196
-
197
-
198
- class HybridEmbed(nn.Module):
199
- """ CNN Feature Map Embedding
200
- Extract feature map from CNN, flatten, project to embedding dim.
201
- """
202
-
203
- def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
204
- super().__init__()
205
- assert isinstance(backbone, nn.Module)
206
- img_size = to_2tuple(img_size)
207
- self.img_size = img_size
208
- self.backbone = backbone
209
- if feature_size is None:
210
- with torch.no_grad():
211
- # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
212
- # map for all networks, the feature metadata has reliable channel and stride info, but using
213
- # stride to calc feature dim requires info about padding of each stage that isn't captured.
214
- training = backbone.training
215
- if training:
216
- backbone.eval()
217
- o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
218
- feature_size = o.shape[-2:]
219
- feature_dim = o.shape[1]
220
- backbone.train(training)
221
- else:
222
- feature_size = to_2tuple(feature_size)
223
- feature_dim = self.backbone.feature_info.channels()[-1]
224
- self.num_patches = feature_size[0] * feature_size[1]
225
- self.proj = nn.Linear(feature_dim, embed_dim)
226
-
227
- def forward(self, x):
228
- x = self.backbone(x)[-1]
229
- x = x.flatten(2).transpose(1, 2)
230
- x = self.proj(x)
231
- return x
232
-
233
-
234
- class TimmVisionTransformer(nn.Module):
235
- """ Vision Transformer with support for patch or hybrid CNN input stage
236
- """
237
-
238
- def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
239
- num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
240
- drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
241
- super().__init__()
242
- self.num_classes = num_classes
243
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
244
-
245
- if hybrid_backbone is not None:
246
- self.patch_embed = HybridEmbed(
247
- hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
248
- else:
249
- self.patch_embed = PatchEmbed(
250
- img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
251
- num_patches = self.patch_embed.num_patches
252
-
253
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
254
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
255
- self.pos_drop = nn.Dropout(p=drop_rate)
256
-
257
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
258
- self.blocks = nn.ModuleList([
259
- Block(
260
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
261
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
262
- for i in range(depth)])
263
- self.norm = norm_layer(embed_dim)
264
-
265
- # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
266
- # self.repr = nn.Linear(embed_dim, representation_size)
267
- # self.repr_act = nn.Tanh()
268
-
269
- # Classifier head
270
- self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
271
-
272
- trunc_normal_(self.pos_embed, std=.02)
273
- trunc_normal_(self.cls_token, std=.02)
274
- self.apply(self._init_weights)
275
-
276
- def _init_weights(self, m):
277
- if isinstance(m, nn.Linear):
278
- trunc_normal_(m.weight, std=.02)
279
- if isinstance(m, nn.Linear) and m.bias is not None:
280
- nn.init.constant_(m.bias, 0)
281
- elif isinstance(m, nn.LayerNorm):
282
- nn.init.constant_(m.bias, 0)
283
- nn.init.constant_(m.weight, 1.0)
284
-
285
- @torch.jit.ignore
286
- def no_weight_decay(self):
287
- return {'pos_embed', 'cls_token'}
288
-
289
- def get_classifier(self):
290
- return self.head
291
-
292
- def reset_classifier(self, num_classes, global_pool=''):
293
- self.num_classes = num_classes
294
- self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
295
-
296
- def forward_features(self, x):
297
- B = x.shape[0]
298
- x = self.patch_embed(x)
299
-
300
- cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
301
- x = torch.cat((cls_tokens, x), dim=1)
302
- x = x + self.pos_embed
303
- x = self.pos_drop(x)
304
-
305
- for blk in self.blocks:
306
- x = blk(x)
307
-
308
- x = self.norm(x)
309
- return x[:, 0]
310
-
311
- def forward(self, x):
312
- x = self.forward_features(x)
313
- x = self.head(x)
314
- return x
315
-
316
-
317
- class VisionTransformer(TimmVisionTransformer):
318
- """ Vision Transformer with support for global average pooling
319
- """
320
-
321
- def __init__(self, **kwargs):
322
- super(VisionTransformer, self).__init__(**kwargs)
323
- norm_layer = kwargs['norm_layer']
324
- embed_dim = kwargs['embed_dim']
325
- self.fc_norm = norm_layer(embed_dim)
326
- del self.norm # remove the original norm
327
-
328
- def interpolate_pos_encoding(self, x, embed):
329
- new_patches = x.shape[1]
330
- old_patches = embed.shape[1]
331
-
332
- w = 8
333
- h = int(new_patches / w)
334
- if new_patches == old_patches:
335
- return embed
336
-
337
- dim = x.shape[-1]
338
- pos_embed = nn.functional.interpolate(
339
- embed.reshape(1, 64, 8, dim).permute(0, 3, 1, 2),
340
- size=(h, w),
341
- mode='bicubic',
342
- )
343
- pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
344
- return pos_embed
345
-
346
- def forward(self, x):
347
- B = x.shape[0]
348
- x = self.patch_embed(x)
349
-
350
- x = x + self.interpolate_pos_encoding(x, self.pos_embed[:, 1:, :])
351
-
352
- cls_token = self.cls_token + self.pos_embed[:, :1, :]
353
- cls_tokens = cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
354
- x = torch.cat((cls_tokens, x), dim=1)
355
- x = self.pos_drop(x)
356
-
357
- for blk in self.blocks:
358
- x = blk(x)
359
-
360
- # x = x[:, 1:, :].mean(dim=1) # global pool without cls token
361
- # outcome = self.fc_norm(x)
362
-
363
- return x[:, 1:, :].reshape(B, -1, 8, 768).permute(0, 3, 2, 1), x[:, 0]
364
-
365
-
366
- class NewPatchEmbed(nn.Module):
367
- """ Flexible Image to Patch Embedding
368
- """
369
-
370
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
371
- super().__init__()
372
- img_size = to_2tuple(img_size)
373
- patch_size = to_2tuple(patch_size)
374
- stride = to_2tuple(stride)
375
- self.img_size = img_size
376
- self.patch_size = patch_size
377
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
378
- _, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
379
- self.patch_hw = (h, w)
380
- self.num_patches = h * w
381
-
382
- def get_output_shape(self, img_size):
383
- # todo: don't be lazy..
384
- return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
385
-
386
- def forward(self, x):
387
- x = self.proj(x)
388
- x = x.flatten(2).transpose(1, 2)
389
- return x
390
-
391
-
392
- def pca(image_feats_list, dim=3, fit_pca=None):
393
- from sklearn.decomposition import PCA
394
-
395
- device = image_feats_list[0].device
396
-
397
- def flatten(tensor, target_size=None):
398
- if target_size is not None and fit_pca is None:
399
- F.interpolate(tensor, (target_size, target_size), mode="bilinear")
400
- B, C, H, W = tensor.shape
401
- return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
402
-
403
- if len(image_feats_list) > 1 and fit_pca is None:
404
- target_size = image_feats_list[0].shape[2]
405
- else:
406
- target_size = None
407
-
408
- flattened_feats = []
409
- for feats in image_feats_list:
410
- flattened_feats.append(flatten(feats, target_size))
411
- x = torch.cat(flattened_feats, dim=0)
412
-
413
- if fit_pca is None:
414
- fit_pca = PCA(n_components=dim, svd_solver="arpack").fit(np.nan_to_num(x.detach().numpy()))
415
-
416
- reduced_feats = []
417
- for feats in image_feats_list:
418
- x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
419
- x_red -= x_red.min(dim=0, keepdim=True).values
420
- x_red /= x_red.max(dim=0, keepdim=True).values
421
- B, C, H, W = feats.shape
422
- reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
423
-
424
- return reduced_feats, fit_pca
425
-
426
-
427
- class AudiosetDataset(Dataset):
428
- def __init__(self, audio_conf):
429
- self.audio_conf = audio_conf
430
- self.melbins = self.audio_conf.get('num_mel_bins')
431
- self.dataset = self.audio_conf.get('dataset')
432
- self.norm_mean = self.audio_conf.get('mean')
433
- self.norm_std = self.audio_conf.get('std')
434
-
435
- print('Dataset: {}, mean {:.3f} and std {:.3f}'.format(self.dataset, self.norm_mean, self.norm_std))
436
- print(f'size of dataset {self.__len__()}')
437
-
438
- def _wav2fbank(self, filename):
439
- sample_rate = 16000
440
- target_length = 10
441
- waveform, obs_sr = torchaudio.load(filename)
442
- waveform = waveform[0]
443
- if obs_sr != sample_rate:
444
- waveform = resample(waveform, obs_sr, sample_rate)
445
-
446
- original_length = waveform.shape[0]
447
- padding = target_length * sample_rate - original_length
448
-
449
- if padding > 0:
450
- m = torch.nn.ZeroPad2d((0, padding))
451
- waveform = m(waveform)
452
- else:
453
- waveform = waveform[:target_length * sample_rate]
454
-
455
-
456
- waveform = waveform - waveform.mean()
457
-
458
- # 498 128, 998, 128
459
- fbank = torchaudio.compliance.kaldi.fbank(
460
- waveform.unsqueeze(0),
461
- htk_compat=True,
462
- sample_frequency=sample_rate,
463
- use_energy=False,
464
- window_type='hanning',
465
- num_mel_bins=128,
466
- dither=0.0,
467
- frame_shift=10)
468
-
469
- normed_fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
470
-
471
- return normed_fbank
472
-
473
- def __getitem__(self, index):
474
- datum = {"wav": "../../samples/example.wav"}
475
- fbank = self._wav2fbank(datum['wav'])
476
- fbank = fbank.transpose(0, 1).unsqueeze(0) # 1, 128, 1024 (...,freq,time)
477
- fbank = torch.transpose(fbank.squeeze(), 0, 1) # time, freq
478
- # the output fbank shape is [time_frame_num, frequency_bins], e.g., [1024, 128]
479
- return fbank.unsqueeze(0)
480
-
481
- def __len__(self):
482
- return 1
483
-
484
-
485
- class AudioMAE(nn.Module):
486
-
487
- def __init__(self, output_path, finetuned):
488
- super().__init__()
489
- # build model
490
- model = VisionTransformer(
491
- patch_size=16,
492
- embed_dim=768,
493
- depth=12,
494
- num_heads=12,
495
- mlp_ratio=4,
496
- qkv_bias=True,
497
- norm_layer=partial(nn.LayerNorm, eps=1e-6),
498
- num_classes=527,
499
- drop_path_rate=0.1)
500
-
501
- img_size = (1024, 128) # 1024, 128
502
- emb_dim = 768
503
- model.patch_embed = NewPatchEmbed(
504
- img_size=img_size, patch_size=(16, 16), in_chans=1, embed_dim=emb_dim, stride=16)
505
- num_patches = model.patch_embed.num_patches
506
- model.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False)
507
-
508
- if finetuned:
509
- fn = "audiomae_finetuned.pth"
510
- else:
511
- fn = "audiomae.pth"
512
-
513
- checkpoint = torch.load(os.path.join(output_path, 'models', fn), map_location='cpu')
514
-
515
- checkpoint_model = checkpoint['model']
516
- state_dict = model.state_dict()
517
- for k in ['head.weight', 'head.bias']:
518
- if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
519
- print(f"Removing key {k} from pretrained checkpoint")
520
- del checkpoint_model[k]
521
- msg = model.load_state_dict(checkpoint_model, strict=False)
522
- print(msg)
523
-
524
- model = model.eval()
525
- self.model = model
526
- self.config = dict(output_path=output_path, finetuned=finetuned)
527
-
528
- def forward(self, audio, include_cls):
529
- patch_tokens, cls_token = self.model(audio)
530
-
531
- if include_cls:
532
- return patch_tokens, cls_token
533
- else:
534
- return patch_tokens
535
-
536
-
537
- if __name__ == '__main__':
538
- import os
539
-
540
- device = torch.device("cuda:2")
541
-
542
- torch.manual_seed(0)
543
- np.random.seed(0)
544
-
545
- model = AudioMAE("../../", True).to(device)
546
-
547
- audio_conf_val = {
548
- 'num_mel_bins': 128,
549
- 'target_length': 1024,
550
- 'dataset': "audioset",
551
- 'mode': 'val',
552
- 'mean': -4.2677393,
553
- 'std': 4.5689974,
554
- }
555
-
556
- dataset = AudiosetDataset(audio_conf=audio_conf_val)
557
-
558
- batch = dataset[0].unsqueeze(0).to(device)
559
-
560
- embeddings = model(batch, include_cls=False)
561
-
562
- import matplotlib.pyplot as plt
563
-
564
- with torch.no_grad():
565
- [pca_feats], _ = pca([embeddings])
566
- plt.imshow(pca_feats.cpu().squeeze(0).permute(1, 2, 0))
567
- plt.show()
568
- print("here")
569
-
570
- print("here")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/CAVMAE.py DELETED
@@ -1,1082 +0,0 @@
1
- import random
2
-
3
- import numpy as np
4
- import timm
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- import torchaudio
9
- import torchvision.transforms as T
10
- from PIL import Image
11
- from timm.models.layers import to_2tuple, DropPath
12
- from timm.models.vision_transformer import Mlp, PatchEmbed, Block
13
- import os
14
-
15
-
16
- class Attention(nn.Module):
17
- def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
18
- super().__init__()
19
- self.num_heads = num_heads
20
- head_dim = dim // num_heads
21
- # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
22
- self.scale = qk_scale or head_dim ** -0.5
23
-
24
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
25
- self.attn_drop = nn.Dropout(attn_drop)
26
- self.proj = nn.Linear(dim, dim)
27
- self.proj_drop = nn.Dropout(proj_drop)
28
-
29
- def forward(self, x):
30
- B, N, C = x.shape
31
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
32
- q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
33
-
34
- attn = (q @ k.transpose(-2, -1)) * self.scale
35
- attn = attn.softmax(dim=-1)
36
- attn = self.attn_drop(attn)
37
-
38
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
39
- x = self.proj(x)
40
- x = self.proj_drop(x)
41
- return x
42
-
43
-
44
- def get_2d_sincos_pos_embed(embed_dim, grid_h_size, grid_w_size, cls_token=False):
45
- """
46
- grid_size: int of the grid height and width
47
- return:
48
- pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
49
- """
50
- grid_h = np.arange(grid_h_size, dtype=float)
51
- grid_w = np.arange(grid_w_size, dtype=float)
52
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
53
- grid = np.stack(grid, axis=0)
54
-
55
- grid = grid.reshape([2, 1, grid_w_size, grid_h_size])
56
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
57
- if cls_token:
58
- pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
59
- return pos_embed
60
-
61
-
62
- def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
63
- assert embed_dim % 2 == 0
64
-
65
- # use half of dimensions to encode grid_h
66
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
67
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
68
-
69
- emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
70
- return emb
71
-
72
-
73
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
74
- """
75
- embed_dim: output dimension for each position
76
- pos: a list of positions to be encoded: size (M,)
77
- out: (M, D)
78
- """
79
- assert embed_dim % 2 == 0
80
- omega = np.arange(embed_dim // 2, dtype=float)
81
- omega /= embed_dim / 2.
82
- omega = 1. / 10000 ** omega # (D/2,)
83
-
84
- pos = pos.reshape(-1) # (M,)
85
- out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
86
-
87
- emb_sin = np.sin(out) # (M, D/2)
88
- emb_cos = np.cos(out) # (M, D/2)
89
-
90
- emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
91
- return emb
92
-
93
-
94
- # --------------------------------------------------------
95
- # Interpolate position embeddings for high-resolution
96
- # References:
97
- # DeiT: https://github.com/facebookresearch/deit
98
- # --------------------------------------------------------
99
- def interpolate_pos_embed(model, checkpoint_model):
100
- if 'pos_embed' in checkpoint_model:
101
- pos_embed_checkpoint = checkpoint_model['pos_embed']
102
- embedding_size = pos_embed_checkpoint.shape[-1]
103
- num_patches = model.patch_embed.num_patches
104
- num_extra_tokens = model.pos_embed.shape[-2] - num_patches
105
- # height (== width) for the checkpoint position embedding
106
- orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
107
- # height (== width) for the new position embedding
108
- new_size = int(num_patches ** 0.5)
109
- # class_token and dist_token are kept unchanged
110
- if orig_size != new_size:
111
- print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
112
- extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
113
- # only the position tokens are interpolated
114
- pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
115
- pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
116
- pos_tokens = torch.nn.functional.interpolate(
117
- pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
118
- pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
119
- new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
120
- checkpoint_model['pos_embed'] = new_pos_embed
121
-
122
-
123
- class PatchEmbed(nn.Module):
124
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
125
- super().__init__()
126
-
127
- img_size = to_2tuple(img_size)
128
- patch_size = to_2tuple(patch_size)
129
- num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
130
- self.img_size = img_size
131
- self.patch_size = patch_size
132
- self.num_patches = num_patches
133
-
134
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
135
-
136
- def forward(self, x):
137
- x = self.proj(x).flatten(2).transpose(1, 2)
138
- return x
139
-
140
-
141
- class Block(nn.Module):
142
- def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
143
- drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
144
- super().__init__()
145
- self.norm1 = norm_layer(dim)
146
- self.norm1_a = norm_layer(dim)
147
- self.norm1_v = norm_layer(dim)
148
- self.attn = Attention(
149
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
150
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
151
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
152
- self.norm2 = norm_layer(dim)
153
- self.norm2_a = norm_layer(dim)
154
- self.norm2_v = norm_layer(dim)
155
- mlp_hidden_dim = int(dim * mlp_ratio)
156
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
157
-
158
- def forward(self, x, modality=None):
159
- if modality == None:
160
- x = x + self.drop_path(self.attn(self.norm1(x)))
161
- x = x + self.drop_path(self.mlp(self.norm2(x)))
162
- elif modality == 'a':
163
- x = x + self.drop_path(self.attn(self.norm1_a(x)))
164
- x = x + self.drop_path(self.mlp(self.norm2_a(x)))
165
- elif modality == 'v':
166
- x = x + self.drop_path(self.attn(self.norm1_v(x)))
167
- x = x + self.drop_path(self.mlp(self.norm2_v(x)))
168
- return x
169
-
170
-
171
- # our main proposed model, for pretraining only, for finetuning, use CAVMAEFT class
172
- class CAVMAE(nn.Module):
173
- """ CAV-MAE Model
174
- """
175
-
176
- def __init__(self, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
177
- embed_dim=768, modality_specific_depth=11, num_heads=12,
178
- decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
179
- mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, tr_pos=False):
180
- super().__init__()
181
- print('A CAV-MAE Model')
182
- print('Use norm_pix_loss: ', norm_pix_loss)
183
- print('Learnable Positional Embedding: ', tr_pos)
184
-
185
- # the encoder part
186
- # overide the timm package
187
- timm.models.vision_transformer.PatchEmbed = PatchEmbed
188
- timm.models.vision_transformer.Block = Block
189
-
190
- self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
191
- self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
192
-
193
- self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
194
- print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(self.patch_embed_a.num_patches,
195
- self.patch_embed_v.num_patches))
196
-
197
- self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
198
- self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))
199
-
200
- self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
201
- requires_grad=tr_pos) # fixed sin-cos embedding
202
- self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
203
- requires_grad=tr_pos) # fixed sin-cos embedding
204
-
205
- # audio-branch
206
- self.blocks_a = nn.ModuleList(
207
- [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
208
- range(modality_specific_depth)])
209
- # visual-branch
210
- self.blocks_v = nn.ModuleList(
211
- [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
212
- range(modality_specific_depth)])
213
- # unified branch
214
- self.blocks_u = nn.ModuleList(
215
- [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
216
- range(12 - modality_specific_depth)])
217
-
218
- # independent normalization layer for audio, visual, and audio-visual
219
- self.norm_a, self.norm_v, self.norm = norm_layer(embed_dim), norm_layer(embed_dim), norm_layer(embed_dim)
220
-
221
- # the decoder part
222
- # Project to lower dimension for the decoder
223
- self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
224
-
225
- # token used for masking
226
- self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
227
-
228
- self.decoder_modality_a = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
229
- self.decoder_modality_v = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
230
-
231
- self.decoder_pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, decoder_embed_dim),
232
- requires_grad=tr_pos) # fixed sin-cos embedding
233
- self.decoder_pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, decoder_embed_dim),
234
- requires_grad=tr_pos) # fixed sin-cos embedding
235
-
236
- self.decoder_blocks = nn.ModuleList(
237
- [Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
238
- for i in range(decoder_depth)])
239
-
240
- self.decoder_norm = norm_layer(decoder_embed_dim)
241
-
242
- # project channel is different for two modality, use two projection head
243
- self.decoder_pred_a = nn.Linear(decoder_embed_dim, patch_size ** 2 * 1, bias=True) # decoder to patch
244
- self.decoder_pred_v = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch
245
-
246
- self.norm_pix_loss = norm_pix_loss
247
-
248
- self.initialize_weights()
249
-
250
- print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
251
- print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)
252
-
253
- def initialize_weights(self):
254
- # initialize (and freeze) pos_embed by sin-cos embedding, opt the cls token, add by myself
255
- pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8, int(self.patch_embed_a.num_patches / 8),
256
- cls_token=False)
257
- self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))
258
-
259
- pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], int(self.patch_embed_v.num_patches ** .5),
260
- int(self.patch_embed_v.num_patches ** .5), cls_token=False)
261
- self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))
262
-
263
- decoder_pos_embed_a = get_2d_sincos_pos_embed(self.decoder_pos_embed_a.shape[-1], 8,
264
- int(self.patch_embed_a.num_patches / 8), cls_token=False)
265
- self.decoder_pos_embed_a.data.copy_(torch.from_numpy(decoder_pos_embed_a).float().unsqueeze(0))
266
-
267
- decoder_pos_embed_v = get_2d_sincos_pos_embed(self.decoder_pos_embed_v.shape[-1],
268
- int(self.patch_embed_v.num_patches ** .5),
269
- int(self.patch_embed_v.num_patches ** .5), cls_token=False)
270
- self.decoder_pos_embed_v.data.copy_(torch.from_numpy(decoder_pos_embed_v).float().unsqueeze(0))
271
-
272
- # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
273
- w = self.patch_embed_a.proj.weight.data
274
- torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
275
- w = self.patch_embed_v.proj.weight.data
276
- torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
277
-
278
- # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
279
- torch.nn.init.normal_(self.modality_a, std=.02)
280
- torch.nn.init.normal_(self.modality_v, std=.02)
281
- torch.nn.init.normal_(self.decoder_modality_a, std=.02)
282
- torch.nn.init.normal_(self.decoder_modality_v, std=.02)
283
- torch.nn.init.normal_(self.mask_token, std=.02)
284
-
285
- # initialize nn.Linear and nn.LayerNorm
286
- self.apply(self._init_weights)
287
-
288
- def _init_weights(self, m):
289
- if isinstance(m, nn.Linear):
290
- # we use xavier_uniform following official JAX ViT:
291
- torch.nn.init.xavier_uniform_(m.weight)
292
- if isinstance(m, nn.Linear) and m.bias is not None:
293
- nn.init.constant_(m.bias, 0)
294
- elif isinstance(m, nn.LayerNorm):
295
- nn.init.constant_(m.bias, 0)
296
- nn.init.constant_(m.weight, 1.0)
297
-
298
- def patchify(self, imgs, c, h, w, p=16):
299
- """
300
- imgs: (N, 3, H, W)
301
- x: (N, L, patch_size**2 *3)
302
- """
303
- x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
304
- x = torch.einsum('nchpwq->nhwpqc', x)
305
- x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * c))
306
- return x
307
-
308
- def unpatchify(self, x, c, h, w, p=16):
309
- """
310
- x: (N, L, patch_size**2 *3)
311
- imgs: (N, 3, H, W)
312
- """
313
- assert h * w == x.shape[1]
314
-
315
- x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
316
- x = torch.einsum('nhwpqc->nchpwq', x)
317
- imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
318
- return imgs
319
-
320
- def random_masking_unstructured(self, x, mask_ratio):
321
- """
322
- Perform per-sample random masking by per-sample shuffling.
323
- Per-sample shuffling is done by argsort random noise.
324
- x: [N, L, D], sequence
325
- """
326
- N, L, D = x.shape # batch, length, dim
327
- len_keep = int(L * (1 - mask_ratio))
328
-
329
- noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
330
-
331
- # sort noise for each sample
332
- ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
333
- ids_restore = torch.argsort(ids_shuffle, dim=1)
334
-
335
- # keep the first subset
336
- ids_keep = ids_shuffle[:, :len_keep]
337
- x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
338
-
339
- # generate the binary mask: 0 is keep, 1 is remove
340
- mask = torch.ones([N, L], device=x.device)
341
- mask[:, :len_keep] = 0
342
- # unshuffle to get the binary mask
343
- mask = torch.gather(mask, dim=1, index=ids_restore)
344
-
345
- return x_masked, mask, ids_restore
346
-
347
- def random_masking_structured(self, x, mask_ratio, t=64, f=8, mode='time'):
348
- """
349
- Perform per-sample random masking by per-sample shuffling.
350
- Per-sample shuffling is done by argsort random noise.
351
- x: [N, L, D], sequence
352
- """
353
- N, L, D = x.shape # batch, length, dim
354
- len_keep = int(L * (1 - mask_ratio))
355
-
356
- noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
357
- assert L == f * t
358
- noise = noise.reshape(N, f, t) # the audio patch is in shape [f,t], not [t,f]
359
- if mode == 'time':
360
- for i in range(N):
361
- mask_t_list = random.sample(range(t), int(t * mask_ratio))
362
- for k in mask_t_list:
363
- noise[i, :, k] = 1.1 # large value will be removed
364
- elif mode == 'freq':
365
- for i in range(N):
366
- mask_f_list = random.sample(range(f), int(f * mask_ratio))
367
- for k in mask_f_list:
368
- noise[i, k, :] = 1.1 # large value will be removed
369
- elif mode == 'tf':
370
- for i in range(N):
371
- mask_t_list = random.sample(range(t), int(t * mask_ratio * 0.7))
372
- for k in mask_t_list:
373
- noise[i, :, k] = 1.1 # large value will be removed
374
- for i in range(N):
375
- mask_f_list = random.sample(range(f), int(f * mask_ratio * 0.7))
376
- for k in mask_f_list:
377
- noise[i, k, :] = 1.1 # large value will be removed
378
- noise = noise.reshape(N, L)
379
-
380
- # sort noise for each sample, only need to manuplate these two ids_shuffle, ids_restore
381
- ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
382
- ids_restore = torch.argsort(ids_shuffle, dim=1)
383
-
384
- # keep the first subset
385
- ids_keep = ids_shuffle[:, :len_keep]
386
- x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
387
-
388
- # generate the binary mask: 0 is keep, 1 is remove
389
- mask = torch.ones([N, L], device=x.device)
390
- mask[:, :len_keep] = 0
391
- # unshuffle to get the binary mask
392
- mask = torch.gather(mask, dim=1, index=ids_restore)
393
-
394
- return x_masked, mask, ids_restore
395
-
396
- def forward_encoder(self, a, v, mask_ratio_a, mask_ratio_v, mask_mode='unstructured'):
397
- # embed patches
398
- a = a.unsqueeze(1)
399
- a = a.transpose(2, 3)
400
- a = self.patch_embed_a(a)
401
- a = a + self.pos_embed_a
402
- a = a + self.modality_a
403
-
404
- v = self.patch_embed_v(v)
405
- v = v + self.pos_embed_v
406
- v = v + self.modality_v
407
-
408
- # by default, we always use unstructured masking
409
- if mask_mode == 'unstructured':
410
- a, mask_a, ids_restore_a = self.random_masking_unstructured(a, mask_ratio_a)
411
- # in ablation study, we tried time/freq/tf masking. mode in ['freq', 'time', 'tf']
412
- else:
413
- a, mask_a, ids_restore_a = self.random_masking_structured(a, mask_ratio_a, t=64, f=8, mode=mask_mode)
414
-
415
- # visual branch always use unstructured masking
416
- v, mask_v, ids_restore_v = self.random_masking_unstructured(v, mask_ratio_v)
417
-
418
- # audio and visual stream, independent blocks
419
- for blk in self.blocks_a:
420
- a = blk(a)
421
-
422
- for blk in self.blocks_v:
423
- v = blk(v)
424
-
425
- x = torch.cat((a, v), dim=1)
426
-
427
- # unified stream, shared blocks_u, but independent normalization layers
428
- for blk in self.blocks_u:
429
- x = blk(x)
430
- x = self.norm(x)
431
-
432
- for blk in self.blocks_u:
433
- ca = blk(a, 'a')
434
- ca = self.norm_a(ca)
435
-
436
- for blk in self.blocks_u:
437
- cv = blk(v, 'v')
438
- cv = self.norm_v(cv)
439
-
440
- return x, mask_a, ids_restore_a, mask_v, ids_restore_v, ca, cv
441
-
442
- def forward_decoder(self, x, mask_a, ids_restore_a, mask_v, ids_restore_v):
443
-
444
- x = self.decoder_embed(x)
445
-
446
- # append mask tokens to sequence
447
- # mask_tokens_a in shape [B, #a_mask_token, mask_token_dim], get the number of masked samples from mask_a[0], which is the first example of the batch, all samples should have same number of masked tokens
448
- mask_tokens_a = self.mask_token.repeat(x.shape[0], int(mask_a[0].sum()), 1)
449
- a_ = torch.cat([x[:, :self.patch_embed_a.num_patches - int(mask_a[0].sum()), :], mask_tokens_a],
450
- dim=1) # no cls token
451
- a_ = torch.gather(a_, dim=1, index=ids_restore_a.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
452
-
453
- # similar for the visual modality
454
- mask_tokens_v = self.mask_token.repeat(x.shape[0], int(mask_v[0].sum()), 1)
455
- v_ = torch.cat([x[:, self.patch_embed_a.num_patches - int(mask_a[0].sum()):, :], mask_tokens_v],
456
- dim=1) # no cls token
457
- v_ = torch.gather(v_, dim=1, index=ids_restore_v.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
458
-
459
- # concatenate audio and visual tokens
460
- x = torch.cat([a_, v_], dim=1)
461
-
462
- decoder_pos_embed = torch.cat([self.decoder_pos_embed_a, self.decoder_pos_embed_v], dim=1)
463
- x = x + decoder_pos_embed
464
-
465
- # add modality indication tokens
466
- x[:, 0:self.patch_embed_a.num_patches, :] = x[:, 0:self.patch_embed_a.num_patches, :] + self.decoder_modality_a
467
- x[:, self.patch_embed_a.num_patches:, :] = x[:, self.patch_embed_a.num_patches:, :] + self.decoder_modality_v
468
-
469
- # apply Transformer blocks
470
- for blk in self.decoder_blocks:
471
- x = blk(x)
472
- x = self.decoder_norm(x)
473
-
474
- # predictor projection
475
- x_a = self.decoder_pred_a(x[:, :self.patch_embed_a.num_patches, :])
476
- x_v = self.decoder_pred_v(x[:, self.patch_embed_a.num_patches:, :])
477
-
478
- # return audio and video tokens
479
- return x_a, x_v
480
-
481
- def forward_contrastive(self, audio_rep, video_rep, bidirect_contrast=False):
482
- # calculate nce loss for mean-visual representation and mean-audio representation
483
-
484
- audio_rep = torch.nn.functional.normalize(audio_rep, dim=-1)
485
- video_rep = torch.nn.functional.normalize(video_rep, dim=-1)
486
-
487
- total = torch.mm(audio_rep, torch.transpose(video_rep, 0, 1)) / 0.05
488
-
489
- # by default we use single directional
490
- if bidirect_contrast == False:
491
- nce = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total, dim=0)))
492
- c_acc = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total, dim=0), dim=0),
493
- torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
494
- return nce, c_acc
495
- else:
496
- nce_1 = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total, dim=0)))
497
- nce_2 = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total.t(), dim=0)))
498
- c_acc_1 = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total, dim=0), dim=0),
499
- torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
500
- c_acc_2 = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total.t(), dim=0), dim=0),
501
- torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
502
- nce = (nce_1 + nce_2) / 2
503
- c_acc = (c_acc_1 + c_acc_2) / 2
504
- return nce, c_acc
505
-
506
- def forward_mae_loss(self, input, pred, mask, modality):
507
- if modality == 'a':
508
- # for audio, need to adjust the shape
509
- input = input.unsqueeze(1)
510
- input = input.transpose(2, 3)
511
- target = self.patchify(input, 1, int(input.shape[2] / self.patch_embed_a.patch_size[0]),
512
- int(input.shape[3] / self.patch_embed_a.patch_size[1]), 16)
513
- elif modality == 'v':
514
- target = self.patchify(input, 3, int(input.shape[2] / self.patch_embed_v.patch_size[0]),
515
- int(input.shape[3] / self.patch_embed_v.patch_size[1]), 16)
516
-
517
- # patch-wise normalization might minorly improve the classification performance, but will make the model lose inpainting function
518
- if self.norm_pix_loss:
519
- mean = target.mean(dim=-1, keepdim=True)
520
- var = target.var(dim=-1, keepdim=True)
521
- target = (target - mean) / (var + 1.e-6) ** .5
522
-
523
- loss = (pred - target) ** 2
524
- loss = loss.mean(dim=-1) # [N, L], mean loss per patch
525
-
526
- loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
527
- return loss
528
-
529
- def forward(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mae_loss_weight=1., contrast_loss_weight=0.01,
530
- mask_mode='unstructured'):
531
- # latent is used for reconstruction (mae), latent_c_{a,v} are used for contrastive learning
532
- latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(audio, imgs,
533
- mask_ratio_a,
534
- mask_ratio_v,
535
- mask_mode=mask_mode)
536
- # if mae loss is used
537
- if mae_loss_weight != 0:
538
- pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v)
539
- loss_mae_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
540
- loss_mae_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
541
- loss_mae = mae_loss_weight * (loss_mae_a + loss_mae_v)
542
- else:
543
- loss_mae_a, loss_mae_v, loss_mae = torch.tensor(0.0, device=audio.device), torch.tensor(0.0,
544
- device=audio.device), torch.tensor(
545
- 0.0, device=audio.device)
546
-
547
- # if contrastive loss is used
548
- if contrast_loss_weight != 0:
549
- # note this is single directional
550
- loss_c, c_acc = self.forward_contrastive(latent_c_a.mean(dim=1), latent_c_v.mean(dim=1))
551
- loss_c = contrast_loss_weight * loss_c
552
- else:
553
- loss_c, c_acc = torch.tensor(0.0, device=audio.device), torch.tensor(0.0, device=audio.device)
554
-
555
- loss = loss_mae + loss_c
556
-
557
- return loss, loss_mae, loss_mae_a, loss_mae_v, loss_c, mask_a, mask_v, c_acc
558
-
559
- # used only for inpainting, ignore if inpainting is not of interest
560
- def forward_inpaint(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mask_mode='unstructured'):
561
- latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(audio, imgs,
562
- mask_ratio_a,
563
- mask_ratio_v,
564
- mask_mode=mask_mode)
565
- pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v) # [N, L, p*p*3]
566
- loss_pixel_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
567
- loss_pixel_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
568
- return pred_a, pred_v, mask_a, mask_v, loss_pixel_a, loss_pixel_v
569
-
570
- # used for retrieval, ignore if retrieval is not of interest
571
- def forward_feat(self, a, v):
572
- # embed patches
573
- a = a.unsqueeze(1)
574
- a = a.transpose(2, 3)
575
- a = self.patch_embed_a(a)
576
- a = a + self.pos_embed_a
577
- a = a + self.modality_a
578
-
579
- v = self.patch_embed_v(v)
580
- v = v + self.pos_embed_v
581
- v = v + self.modality_v
582
-
583
- # the modality-specific stream
584
- for blk in self.blocks_a:
585
- a = blk(a)
586
-
587
- for blk in self.blocks_v:
588
- v = blk(v)
589
-
590
- # use modality specific normalization,
591
- for blk in self.blocks_u:
592
- a = blk(a, 'a')
593
- a = self.norm_a(a)
594
-
595
- for blk in self.blocks_u:
596
- v = blk(v, 'v')
597
- v = self.norm_v(v)
598
- return a, v
599
-
600
- def forward_audio(self, a):
601
- # embed patches
602
- a = a.unsqueeze(1)
603
- a = a.transpose(2, 3)
604
- a = self.patch_embed_a(a)
605
- a = a + self.pos_embed_a
606
- a = a + self.modality_a
607
-
608
- # the modality-specific stream
609
- for blk in self.blocks_a:
610
- a = blk(a)
611
-
612
- # use modality specific normalization,
613
- for blk in self.blocks_u:
614
- a = blk(a, 'a')
615
- a = self.norm_a(a)
616
-
617
- return a.reshape(a.shape[0], 128 // 16, 1024 // 16, 768).permute(0, 3, 1, 2)
618
-
619
- def forward_video(self, v):
620
- v = self.patch_embed_v(v)
621
- v = v + self.pos_embed_v
622
- v = v + self.modality_v
623
-
624
- for blk in self.blocks_v:
625
- v = blk(v)
626
-
627
- for blk in self.blocks_u:
628
- v = blk(v, 'v')
629
- v = self.norm_v(v)
630
- return v.reshape(v.shape[0], 224 // 16, 224 // 16, 768).permute(0, 3, 1, 2)
631
-
632
-
633
- # the finetuned CAV-MAE model
634
- class CAVMAEFT(nn.Module):
635
- def __init__(self, label_dim, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
636
- embed_dim=768, modality_specific_depth=11, num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm,
637
- norm_pix_loss=False, tr_pos=True):
638
- super().__init__()
639
- timm.models.vision_transformer.Block = Block
640
- print('Use norm_pix_loss: ', norm_pix_loss)
641
-
642
- timm.models.vision_transformer.PatchEmbed = PatchEmbed
643
- timm.models.vision_transformer.Block = Block
644
-
645
- self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
646
- self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
647
-
648
- self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
649
- print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(self.patch_embed_a.num_patches,
650
- self.patch_embed_v.num_patches))
651
-
652
- self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
653
- self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))
654
-
655
- self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
656
- requires_grad=tr_pos) # fixed sin-cos embedding
657
- self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
658
- requires_grad=tr_pos) # fixed sin-cos embedding
659
-
660
- self.blocks_a = nn.ModuleList(
661
- [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
662
- range(modality_specific_depth)])
663
- self.blocks_v = nn.ModuleList(
664
- [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
665
- range(modality_specific_depth)])
666
- self.blocks_u = nn.ModuleList(
667
- [Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
668
- range(12 - modality_specific_depth)])
669
-
670
- self.norm_a = norm_layer(embed_dim)
671
- self.norm_v = norm_layer(embed_dim)
672
- self.norm = norm_layer(embed_dim)
673
-
674
- self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, label_dim))
675
-
676
- self.initialize_weights()
677
-
678
- print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
679
- print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)
680
-
681
- def get_patch_num(self, input_shape, stride):
682
- test_input = torch.zeros(1, 1, input_shape[0], input_shape[1])
683
- test_proj = torch.nn.Conv2d(1, 4, kernel_size=(16, 16), stride=(stride, stride))
684
- test_output = test_proj(test_input)
685
- print(test_output.shape)
686
- return test_output.shape[2], test_output[3], test_output[2] * test_output[2]
687
-
688
- def initialize_weights(self):
689
- pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8, int(self.patch_embed_a.num_patches / 8),
690
- cls_token=False)
691
- self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))
692
-
693
- pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], int(self.patch_embed_v.num_patches ** .5),
694
- int(self.patch_embed_v.num_patches ** .5), cls_token=False)
695
- self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))
696
-
697
- w = self.patch_embed_a.proj.weight.data
698
- torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
699
- w = self.patch_embed_v.proj.weight.data
700
- torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
701
-
702
- torch.nn.init.normal_(self.modality_a, std=.02)
703
- torch.nn.init.normal_(self.modality_v, std=.02)
704
-
705
- self.apply(self._init_weights)
706
-
707
- def _init_weights(self, m):
708
- if isinstance(m, nn.Linear):
709
- # we use xavier_uniform following official JAX ViT:
710
- torch.nn.init.xavier_uniform_(m.weight)
711
- if isinstance(m, nn.Linear) and m.bias is not None:
712
- nn.init.constant_(m.bias, 0)
713
- elif isinstance(m, nn.LayerNorm):
714
- nn.init.constant_(m.bias, 0)
715
- nn.init.constant_(m.weight, 1.0)
716
-
717
- def forward(self, a, v, mode):
718
- # multi-modal fine-tuning, our default method for fine-tuning
719
- if mode == 'multimodal':
720
- a = a.unsqueeze(1)
721
- a = a.transpose(2, 3)
722
- a = self.patch_embed_a(a)
723
- a = a + self.pos_embed_a
724
- a = a + self.modality_a
725
-
726
- v = self.patch_embed_v(v)
727
- v = v + self.pos_embed_v
728
- v = v + self.modality_v
729
-
730
- for blk in self.blocks_a:
731
- a = blk(a)
732
-
733
- for blk in self.blocks_v:
734
- v = blk(v)
735
-
736
- x = torch.cat((a, v), dim=1)
737
-
738
- for blk in self.blocks_u:
739
- x = blk(x)
740
- x = self.norm(x)
741
-
742
- x = x.mean(dim=1)
743
- x = self.mlp_head(x)
744
- return x
745
-
746
- # finetune with only audio (and inference with only audio when the model is finetuned with only audio)
747
- elif mode == 'audioonly':
748
- a = a.unsqueeze(1)
749
- a = a.transpose(2, 3)
750
- a = self.patch_embed_a(a)
751
- a = a + self.pos_embed_a
752
- a = a + self.modality_a
753
-
754
- for blk in self.blocks_a:
755
- a = blk(a)
756
-
757
- # note here uses the 'a' normalization, it is used in both training and inference, so it is fine
758
- for blk in self.blocks_u:
759
- a = blk(a, 'a')
760
- a = self.norm_a(a)
761
- x = a.mean(dim=1)
762
- x = self.mlp_head(x)
763
- return x
764
-
765
- # finetune with only image (and inference with only audio when the model is finetuned with only image)
766
- elif mode == 'videoonly':
767
- v = self.patch_embed_v(v)
768
- v = v + self.pos_embed_v
769
- v = v + self.modality_v
770
-
771
- for blk in self.blocks_v:
772
- v = blk(v)
773
-
774
- # note here uses the 'v' normalization, it is used in both training and inference, so it is fine
775
- for blk in self.blocks_u:
776
- v = blk(v, 'v')
777
- v = self.norm_v(v)
778
- x = v.mean(dim=1)
779
- x = self.mlp_head(x)
780
- return x
781
-
782
- # used in case that the model is finetuned with both modality, but in inference only audio is given
783
- elif mode == 'missingaudioonly':
784
- a = a.unsqueeze(1)
785
- a = a.transpose(2, 3)
786
- a = self.patch_embed_a(a)
787
- a = a + self.pos_embed_a
788
- a = a + self.modality_a
789
-
790
- for blk in self.blocks_a:
791
- a = blk(a)
792
-
793
- # two forward passes to the block_u, one with modality-specific normalization, another with unified normalization
794
- u = a
795
- for blk in self.blocks_u:
796
- u = blk(u) # note here use unified normalization
797
- u = self.norm(u)
798
- u = u.mean(dim=1)
799
-
800
- for blk in self.blocks_u:
801
- a = blk(a, 'a') # note here use modality-specific normalization
802
- a = self.norm_a(a)
803
- a = a.mean(dim=1)
804
-
805
- # average the output of the two forward passes
806
- x = (u + a) / 2
807
- x = self.mlp_head(x)
808
- return x
809
-
810
- # used in case that the model is fine-tuned with both modality, but in inference only image is given
811
- elif mode == 'missingvideoonly':
812
- v = self.patch_embed_v(v)
813
- v = v + self.pos_embed_v
814
- v = v + self.modality_v
815
-
816
- for blk in self.blocks_v:
817
- v = blk(v)
818
-
819
- # two forward passes to the block_u, one with modality-specific normalization, another with unified normalization
820
- u = v
821
- for blk in self.blocks_u:
822
- u = blk(u) # note here use unified normalization
823
- u = self.norm(u)
824
- u = u.mean(dim=1)
825
-
826
- for blk in self.blocks_u:
827
- v = blk(v, 'v') # note here use modality-specific normalization
828
- v = self.norm_v(v)
829
- v = v.mean(dim=1)
830
-
831
- # average the output of the two forward passes
832
- x = (u + v) / 2
833
- x = self.mlp_head(x)
834
- return x
835
-
836
- # for retrieval
837
- def forward_feat(self, a, v, mode='av'):
838
- # return both audio and visual
839
- if mode == 'av':
840
- a = a.unsqueeze(1)
841
- a = a.transpose(2, 3)
842
- a = self.patch_embed_a(a)
843
- a = a + self.pos_embed_a
844
- a = a + self.modality_a
845
-
846
- v = self.patch_embed_v(v)
847
- v = v + self.pos_embed_v
848
- v = v + self.modality_v
849
-
850
- for blk in self.blocks_a:
851
- a = blk(a)
852
-
853
- for blk in self.blocks_v:
854
- v = blk(v)
855
-
856
- for blk in self.blocks_u:
857
- a = blk(a, 'a')
858
- a = self.norm_a(a)
859
-
860
- for blk in self.blocks_u:
861
- v = blk(v, 'v')
862
-
863
- v = self.norm_v(v)
864
- return a, v
865
-
866
- # return only audio
867
- if mode == 'a':
868
- a = a.unsqueeze(1)
869
- a = a.transpose(2, 3)
870
- a = self.patch_embed_a(a)
871
- a = a + self.pos_embed_a
872
- a = a + self.modality_a
873
-
874
- for blk in self.blocks_a:
875
- a = blk(a)
876
-
877
- for blk in self.blocks_u:
878
- a = blk(a, 'a')
879
-
880
- a = self.norm_a(a)
881
- return a
882
-
883
-
884
- def _wav2fbank(filename):
885
- waveform, sr = torchaudio.load(filename)
886
- waveform = torchaudio.functional.resample(
887
- waveform, orig_freq=sr, new_freq=16000
888
- )
889
-
890
- waveform = waveform - waveform.mean()
891
- waveform
892
- print(sr)
893
-
894
- fbank = torchaudio.compliance.kaldi.fbank(
895
- waveform,
896
- htk_compat=True,
897
- sample_frequency=sr,
898
- use_energy=False,
899
- window_type='hanning',
900
- num_mel_bins=128,
901
- dither=0.0,
902
- frame_shift=10)
903
-
904
- target_length = 1024
905
- n_frames = fbank.shape[0]
906
-
907
- p = target_length - n_frames
908
-
909
- # cut and pad
910
- if p > 0:
911
- m = torch.nn.ZeroPad2d((0, 0, 0, p))
912
- fbank = m(fbank)
913
- elif p < 0:
914
- fbank = fbank[0:target_length, :]
915
-
916
- return fbank
917
-
918
-
919
- def pca(image_feats_list, dim=3, fit_pca=None):
920
- from sklearn.decomposition import PCA
921
-
922
- device = image_feats_list[0].device
923
-
924
- def flatten(tensor, target_size=None):
925
- if target_size is not None and fit_pca is None:
926
- F.interpolate(tensor, (target_size, target_size), mode="bilinear")
927
- B, C, H, W = tensor.shape
928
- return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
929
-
930
- if len(image_feats_list) > 1 and fit_pca is None:
931
- target_size = image_feats_list[0].shape[2]
932
- else:
933
- target_size = None
934
-
935
- flattened_feats = []
936
- for feats in image_feats_list:
937
- flattened_feats.append(flatten(feats, target_size))
938
- x = torch.cat(flattened_feats, dim=0)
939
-
940
- if fit_pca is None:
941
- fit_pca = PCA(n_components=dim).fit(x)
942
-
943
- reduced_feats = []
944
- for feats in image_feats_list:
945
- x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
946
- x_red -= x_red.min(dim=0, keepdim=True).values
947
- x_red /= x_red.max(dim=0, keepdim=True).values
948
- B, C, H, W = feats.shape
949
- reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
950
-
951
- return reduced_feats, fit_pca
952
-
953
-
954
- class CAVMAEAudioFeaturizer(nn.Module):
955
-
956
- def __init__(self, output_path, model_name="base", model=None):
957
- super().__init__()
958
- if model is not None:
959
- self.model = model
960
- else:
961
- if model_name == "base":
962
- model_path = os.path.join(output_path, 'models/audio_model.21.pth')
963
- else:
964
- raise ValueError(f"Unknown model type {model_name}")
965
-
966
- audio_model = CAVMAE(
967
- audio_length=1024,
968
- modality_specific_depth=11,
969
- norm_pix_loss=True,
970
- tr_pos=False)
971
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
972
- mdl_weight = torch.load(model_path, map_location=device)
973
- audio_model = torch.nn.DataParallel(audio_model)
974
- audio_model.load_state_dict(mdl_weight, strict=True)
975
- self.model = audio_model.module.cuda()
976
-
977
- def forward(self, audio, include_cls):
978
- cls_token = None
979
- patch_tokens = self.model.forward_audio(audio.squeeze(1))
980
-
981
- if include_cls:
982
- return patch_tokens, cls_token
983
- else:
984
- return patch_tokens
985
-
986
-
987
- class CAVMAEImageFeaturizer(nn.Module):
988
-
989
- def __init__(self, output_path, model=None, model_name="base"):
990
- super().__init__()
991
- if model is not None:
992
- self.model: CAVMAE = model
993
- else:
994
- if model_name == "base":
995
- model_path = os.path.join(output_path, 'models/audio_model.21.pth')
996
- else:
997
- raise ValueError(f"Unknown model type {model_name}")
998
-
999
- audio_model = CAVMAE(
1000
- audio_length=1024,
1001
- modality_specific_depth=11,
1002
- norm_pix_loss=True,
1003
- tr_pos=False)
1004
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1005
- mdl_weight = torch.load(model_path, map_location=device)
1006
- audio_model = torch.nn.DataParallel(audio_model)
1007
- audio_model.load_state_dict(mdl_weight, strict=True)
1008
- self.model: CAVMAE = audio_model.module.cuda()
1009
-
1010
- def forward(self, image, include_cls):
1011
- cls_token = None
1012
- patch_tokens = self.model.forward_video(image)
1013
-
1014
- if include_cls:
1015
- return patch_tokens, cls_token
1016
- else:
1017
- return patch_tokens
1018
-
1019
-
1020
- if __name__ == "__main__":
1021
- model_path = os.path.join("../../", 'models/audio_model.21.pth')
1022
- audio_model = CAVMAE(
1023
- audio_length=1024,
1024
- modality_specific_depth=11,
1025
- norm_pix_loss=True,
1026
- tr_pos=False)
1027
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1028
- mdl_weight = torch.load(model_path, map_location=device)
1029
- audio_model = torch.nn.DataParallel(audio_model)
1030
- audio_model.load_state_dict(mdl_weight, strict=True)
1031
- model: CAVMAE = audio_model.module.cuda()
1032
-
1033
- image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"]
1034
- audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"]
1035
-
1036
- images = []
1037
- audios = []
1038
-
1039
- for image_path in image_paths:
1040
- image = Image.open(image_path).convert("RGB")
1041
- preprocess = T.Compose([
1042
- T.Resize(224, interpolation=Image.BICUBIC),
1043
- T.CenterCrop(224),
1044
- T.ToTensor(),
1045
- T.Normalize(
1046
- mean=[0.4850, 0.4560, 0.4060],
1047
- std=[0.2290, 0.2240, 0.2250]
1048
- )])
1049
- images.append(preprocess(image).unsqueeze(0).cuda())
1050
-
1051
- for audio_path in audio_paths:
1052
- a = _wav2fbank(audio_path).cuda().unsqueeze(0)
1053
- a = (a + 5.081) / (4.4849)
1054
- audios.append(a)
1055
-
1056
- audio_feats, image_feats = model.forward_feat(
1057
- torch.cat(audios, dim=0), torch.cat(images, dim=0))
1058
-
1059
- audio_feats = F.normalize(audio_feats.mean(1), dim=1)
1060
- image_feats = F.normalize(image_feats.mean(1), dim=1)
1061
-
1062
- sims = torch.einsum("bc,dc->bd", image_feats, audio_feats)
1063
- print(sims)
1064
-
1065
- print("here")
1066
-
1067
- # a_feat = F.normalize(a_feat, dim=1)
1068
- # v_feat = F.normalize(v_feat, dim=1)
1069
-
1070
- # [red_v_feat, red_a_feat], fit_pca = pca([v_feat, a_feat])
1071
- #
1072
- # [red_v_feat], fit_pca = pca([v_feat])
1073
- # [red_a_feat], fit_pca = pca([a_feat])
1074
- #
1075
- # import matplotlib.pyplot as plt
1076
- #
1077
- # fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 5))
1078
- # ax[0].imshow(red_v_feat[0].permute(1, 2, 0).cpu())
1079
- # ax[1].imshow(red_a_feat[0].permute(1, 2, 0).cpu())
1080
- # plt.tight_layout()
1081
- # plt.show()
1082
- # print("here")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/CLIP.py DELETED
@@ -1,50 +0,0 @@
1
- import clip
2
- import torch
3
- from torch import nn
4
-
5
-
6
- class CLIPFeaturizer(nn.Module):
7
-
8
- def __init__(self):
9
- super().__init__()
10
- self.model, self.preprocess = clip.load("ViT-B/16", device="cpu")
11
- self.model.eval().cuda()
12
- self.config = {}
13
-
14
- def get_cls_token(self, img):
15
- return self.model.encode_image(img).to(torch.float32)
16
-
17
- def forward(self, img, include_cls):
18
- features = self.model.get_visual_features(img, include_cls)
19
- new_features = []
20
- for i in range(2):
21
- t = features[i]
22
- if isinstance(t, torch.Tensor):
23
- new_features.append(t.to(torch.float32))
24
- else:
25
- new_features.append(t)
26
-
27
- return new_features
28
-
29
-
30
- if __name__ == "__main__":
31
- import torchvision.transforms as T
32
- from PIL import Image
33
- from shared import norm, crop_to_divisor
34
-
35
- device = "cuda" if torch.cuda.is_available() else "cpu"
36
-
37
- image = Image.open("../samples/lex1.jpg")
38
- load_size = 224 # * 3
39
- transform = T.Compose([
40
- T.Resize(load_size, Image.BILINEAR),
41
- # T.CenterCrop(load_size),
42
- T.ToTensor(),
43
- lambda x: crop_to_divisor(x, 16),
44
- norm])
45
-
46
- model = CLIPFeaturizer().cuda()
47
-
48
- results = model(transform(image).cuda().unsqueeze(0))
49
-
50
- print(clip.available_models())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/DAVENet.py DELETED
@@ -1,162 +0,0 @@
1
- # Author: David Harwath
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional
5
- import torch.nn.functional
6
- import torch.nn.functional as F
7
- import torch.utils.model_zoo as model_zoo
8
- import torchvision.models as imagemodels
9
-
10
-
11
- class Davenet(nn.Module):
12
- def __init__(self, embedding_dim=1024):
13
- super(Davenet, self).__init__()
14
- self.embedding_dim = embedding_dim
15
- self.batchnorm1 = nn.BatchNorm2d(1)
16
- self.conv1 = nn.Conv2d(1, 128, kernel_size=(40, 1), stride=(1, 1), padding=(0, 0))
17
- self.conv2 = nn.Conv2d(128, 256, kernel_size=(1, 11), stride=(1, 1), padding=(0, 5))
18
- self.conv3 = nn.Conv2d(256, 512, kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
19
- self.conv4 = nn.Conv2d(512, 512, kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
20
- self.conv5 = nn.Conv2d(512, embedding_dim, kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
21
- self.pool = nn.MaxPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))
22
-
23
- def forward(self, x):
24
- if x.dim() == 3:
25
- x = x.unsqueeze(1)
26
- x = self.batchnorm1(x)
27
- x = F.relu(self.conv1(x))
28
- x = F.relu(self.conv2(x))
29
- x = self.pool(x)
30
- x = F.relu(self.conv3(x))
31
- x = self.pool(x)
32
- x = F.relu(self.conv4(x))
33
- x = self.pool(x)
34
- x = F.relu(self.conv5(x))
35
- x = self.pool(x)
36
- x = x.squeeze(2)
37
- return x
38
-
39
-
40
- class Resnet18(imagemodels.ResNet):
41
- def __init__(self, embedding_dim=1024, pretrained=False):
42
- super(Resnet18, self).__init__(imagemodels.resnet.BasicBlock, [2, 2, 2, 2])
43
- if pretrained:
44
- self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet18']))
45
- self.avgpool = None
46
- self.fc = None
47
- self.embedder = nn.Conv2d(512, embedding_dim, kernel_size=1, stride=1, padding=0)
48
- self.embedding_dim = embedding_dim
49
- self.pretrained = pretrained
50
-
51
- def forward(self, x):
52
- x = self.conv1(x)
53
- x = self.bn1(x)
54
- x = self.relu(x)
55
- x = self.maxpool(x)
56
- x = self.layer1(x)
57
- x = self.layer2(x)
58
- x = self.layer3(x)
59
- x = self.layer4(x)
60
- x = self.embedder(x)
61
- return x
62
-
63
-
64
- class Resnet34(imagemodels.ResNet):
65
- def __init__(self, embedding_dim=1024, pretrained=False):
66
- super(Resnet34, self).__init__(imagemodels.resnet.BasicBlock, [3, 4, 6, 3])
67
- if pretrained:
68
- self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet34']))
69
- self.avgpool = None
70
- self.fc = None
71
- self.embedder = nn.Conv2d(512, embedding_dim, kernel_size=1, stride=1, padding=0)
72
-
73
- def forward(self, x):
74
- x = self.conv1(x)
75
- x = self.bn1(x)
76
- x = self.relu(x)
77
- x = self.maxpool(x)
78
- x = self.layer1(x)
79
- x = self.layer2(x)
80
- x = self.layer3(x)
81
- x = self.layer4(x)
82
- x = self.embedder(x)
83
- return x
84
-
85
-
86
- class Resnet50(imagemodels.ResNet):
87
- def __init__(self, embedding_dim=1024, pretrained=False):
88
- super(Resnet50, self).__init__(imagemodels.resnet.Bottleneck, [3, 4, 6, 3])
89
- if pretrained:
90
- self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet50']))
91
- self.avgpool = None
92
- self.fc = None
93
- self.embedder = nn.Conv2d(2048, embedding_dim, kernel_size=1, stride=1, padding=0)
94
-
95
- def forward(self, x):
96
- x = self.conv1(x)
97
- x = self.bn1(x)
98
- x = self.relu(x)
99
- x = self.maxpool(x)
100
- x = self.layer1(x)
101
- x = self.layer2(x)
102
- x = self.layer3(x)
103
- x = self.layer4(x)
104
- x = self.embedder(x)
105
- return x
106
-
107
-
108
- class VGG16(nn.Module):
109
- def __init__(self, embedding_dim=1024, pretrained=False):
110
- super(VGG16, self).__init__()
111
- seed_model = imagemodels.__dict__['vgg16'](pretrained=pretrained).features
112
- seed_model = nn.Sequential(*list(seed_model.children())[:-1]) # remove final maxpool
113
- last_layer_index = len(list(seed_model.children()))
114
- seed_model.add_module(str(last_layer_index),
115
- nn.Conv2d(512, embedding_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
116
- self.image_model = seed_model
117
-
118
- def forward(self, x):
119
- x = self.image_model(x)
120
- return x
121
-
122
-
123
- def prep(dict):
124
- return {k.replace("module.", ""): v for k, v in dict.items()}
125
-
126
-
127
- class DavenetAudioFeaturizer(nn.Module):
128
-
129
- def __init__(self):
130
- super().__init__()
131
- self.audio_model = Davenet()
132
- self.audio_model.load_state_dict(prep(torch.load("../models/davenet_pt_audio.pth")))
133
-
134
- def forward(self, audio, include_cls):
135
- patch_tokens = self.audio_model(audio).unsqueeze(2)
136
-
137
- if include_cls:
138
- return patch_tokens, None
139
- else:
140
- return patch_tokens
141
-
142
- def get_last_params(self):
143
- return []
144
-
145
-
146
- class DavenetImageFeaturizer(nn.Module):
147
-
148
- def __init__(self):
149
- super().__init__()
150
- self.image_model = VGG16()
151
- self.image_model.load_state_dict(prep(torch.load("../models/davenet_pt_image.pth")))
152
-
153
- def forward(self, image, include_cls):
154
- patch_tokens = self.image_model(image)
155
-
156
- if include_cls:
157
- return patch_tokens, None
158
- else:
159
- return patch_tokens
160
-
161
- def get_last_params(self):
162
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/DINO.py DELETED
@@ -1,451 +0,0 @@
1
- import math
2
- import warnings
3
- from functools import partial
4
-
5
- import timm
6
- import torch
7
- import torch.nn as nn
8
-
9
- eps = 1e-4
10
-
11
- def _no_grad_trunc_normal_(tensor, mean, std, a, b):
12
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
13
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
14
- def norm_cdf(x):
15
- # Computes standard normal cumulative distribution function
16
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
17
-
18
- if (mean < a - 2 * std) or (mean > b + 2 * std):
19
- warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
20
- "The distribution of values may be incorrect.",
21
- stacklevel=2)
22
-
23
- with torch.no_grad():
24
- # Values are generated by using a truncated uniform distribution and
25
- # then using the inverse CDF for the normal distribution.
26
- # Get upper and lower cdf values
27
- l = norm_cdf((a - mean) / std)
28
- u = norm_cdf((b - mean) / std)
29
-
30
- # Uniformly fill tensor with values from [l, u], then translate to
31
- # [2l-1, 2u-1].
32
- tensor.uniform_(2 * l - 1, 2 * u - 1)
33
-
34
- # Use inverse cdf transform for normal distribution to get truncated
35
- # standard normal
36
- tensor.erfinv_()
37
-
38
- # Transform to proper mean, std
39
- tensor.mul_(std * math.sqrt(2.))
40
- tensor.add_(mean)
41
-
42
- # Clamp to ensure it's in the proper range
43
- tensor.clamp_(min=a, max=b)
44
- return tensor
45
-
46
-
47
- def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
48
- # type: (Tensor, float, float, float, float) -> Tensor
49
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
50
-
51
-
52
-
53
- def drop_path(x, drop_prob: float = 0., training: bool = False):
54
- if drop_prob == 0. or not training:
55
- return x
56
- keep_prob = 1 - drop_prob
57
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
58
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
59
- random_tensor.floor_() # binarize
60
- output = x.div(keep_prob) * random_tensor
61
- return output
62
-
63
-
64
- class DropPath(nn.Module):
65
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
66
- """
67
-
68
- def __init__(self, drop_prob=None):
69
- super(DropPath, self).__init__()
70
- self.drop_prob = drop_prob
71
-
72
- def forward(self, x):
73
- return drop_path(x, self.drop_prob, self.training)
74
-
75
-
76
- class Mlp(nn.Module):
77
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
78
- super().__init__()
79
- out_features = out_features or in_features
80
- hidden_features = hidden_features or in_features
81
- self.fc1 = nn.Linear(in_features, hidden_features)
82
- self.act = act_layer()
83
- self.fc2 = nn.Linear(hidden_features, out_features)
84
- self.drop = nn.Dropout(drop)
85
-
86
- def forward(self, x):
87
- x = self.fc1(x)
88
- x = self.act(x)
89
- x = self.drop(x)
90
- x = self.fc2(x)
91
- x = self.drop(x)
92
- return x
93
-
94
-
95
- class Attention(nn.Module):
96
- def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
97
- super().__init__()
98
- self.num_heads = num_heads
99
- head_dim = dim // num_heads
100
- self.scale = qk_scale or head_dim ** -0.5
101
-
102
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
103
- self.attn_drop = nn.Dropout(attn_drop)
104
- self.proj = nn.Linear(dim, dim)
105
- self.proj_drop = nn.Dropout(proj_drop)
106
-
107
- def forward(self, x, return_qkv=False):
108
- B, N, C = x.shape
109
- qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
110
- q, k, v = qkv[0], qkv[1], qkv[2]
111
-
112
- attn = (q @ k.transpose(-2, -1)) * self.scale
113
- attn = attn.softmax(dim=-1)
114
- attn = self.attn_drop(attn)
115
-
116
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
117
- x = self.proj(x)
118
- x = self.proj_drop(x)
119
- return x, attn, qkv
120
-
121
-
122
- class Block(nn.Module):
123
- def __init__(self, dim,
124
- num_heads,
125
- mlp_ratio=4.,
126
- qkv_bias=False,
127
- qk_scale=None,
128
- drop=0.,
129
- attn_drop=0.,
130
- drop_path=0.,
131
- act_layer=nn.GELU,
132
- norm_layer=nn.LayerNorm):
133
- super().__init__()
134
- self.norm1 = norm_layer(dim)
135
- self.attn = Attention(
136
- dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
137
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
138
- self.norm2 = norm_layer(dim)
139
- mlp_hidden_dim = int(dim * mlp_ratio)
140
- self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
141
-
142
- def forward(self, x, return_attention=False, return_qkv=False):
143
- y, attn, qkv = self.attn(self.norm1(x))
144
- if return_attention:
145
- return attn
146
- x = x + self.drop_path(y)
147
- x = x + self.drop_path(self.mlp(self.norm2(x)))
148
- if return_qkv:
149
- return x, attn, qkv
150
- return x
151
-
152
-
153
- class PatchEmbed(nn.Module):
154
- """ Image to Patch Embedding
155
- """
156
-
157
- def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
158
- super().__init__()
159
- num_patches = (img_size // patch_size) * (img_size // patch_size)
160
- self.img_size = img_size
161
- self.patch_size = patch_size
162
- self.num_patches = num_patches
163
-
164
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
165
-
166
- def forward(self, x):
167
- B, C, H, W = x.shape
168
- x = self.proj(x).flatten(2).transpose(1, 2)
169
- return x
170
-
171
-
172
- class VisionTransformer(nn.Module):
173
- """ Vision Transformer """
174
-
175
- def __init__(self,
176
- img_size=[224],
177
- patch_size=16,
178
- in_chans=3,
179
- num_classes=0,
180
- embed_dim=768,
181
- depth=12,
182
- num_heads=12,
183
- mlp_ratio=4.,
184
- qkv_bias=False,
185
- qk_scale=None,
186
- drop_rate=0.,
187
- attn_drop_rate=0.,
188
- drop_path_rate=0.,
189
- norm_layer=nn.LayerNorm,
190
- **kwargs):
191
- super().__init__()
192
-
193
- self.num_features = self.embed_dim = embed_dim
194
-
195
- self.patch_embed = PatchEmbed(
196
- img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
197
- num_patches = self.patch_embed.num_patches
198
-
199
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
200
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
201
- self.pos_drop = nn.Dropout(p=drop_rate)
202
-
203
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
204
- self.blocks = nn.ModuleList([
205
- Block(
206
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
207
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
208
- for i in range(depth)])
209
- self.norm = norm_layer(embed_dim)
210
-
211
- # Classifier head
212
- self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
213
-
214
- trunc_normal_(self.pos_embed, std=.02)
215
- trunc_normal_(self.cls_token, std=.02)
216
- self.apply(self._init_weights)
217
-
218
- def _init_weights(self, m):
219
- if isinstance(m, nn.Linear):
220
- trunc_normal_(m.weight, std=.02)
221
- if isinstance(m, nn.Linear) and m.bias is not None:
222
- nn.init.constant_(m.bias, 0)
223
- elif isinstance(m, nn.LayerNorm):
224
- nn.init.constant_(m.bias, 0)
225
- nn.init.constant_(m.weight, 1.0)
226
-
227
- def interpolate_pos_encoding(self, x, w, h):
228
- npatch = x.shape[1] - 1
229
- N = self.pos_embed.shape[1] - 1
230
- if npatch == N and w == h:
231
- return self.pos_embed
232
- class_pos_embed = self.pos_embed[:, 0]
233
- patch_pos_embed = self.pos_embed[:, 1:]
234
- dim = x.shape[-1]
235
- w0 = w // self.patch_embed.patch_size
236
- h0 = h // self.patch_embed.patch_size
237
- # we add a small number to avoid floating point error in the interpolation
238
- # see discussion at https://github.com/facebookresearch/dino/issues/8
239
- w0, h0 = w0 + 0.1, h0 + 0.1
240
- patch_pos_embed = nn.functional.interpolate(
241
- patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
242
- scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
243
- mode='bicubic',
244
- )
245
- assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
246
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
247
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
248
-
249
- def prepare_tokens(self, x):
250
- B, nc, w, h = x.shape
251
- x = self.patch_embed(x) # patch linear embedding
252
-
253
- # add the [CLS] token to the embed patch tokens
254
- cls_tokens = self.cls_token.expand(B, -1, -1)
255
- x = torch.cat((cls_tokens, x), dim=1)
256
-
257
- # add positional encoding to each token
258
- x = x + self.interpolate_pos_encoding(x, w, h)
259
-
260
- return self.pos_drop(x)
261
-
262
- def forward(self, x):
263
- x = self.prepare_tokens(x)
264
- for blk in self.blocks:
265
- x = blk(x)
266
- x = self.norm(x)
267
- return x[:, 0]
268
-
269
- def forward_feats(self, x):
270
- x = self.prepare_tokens(x)
271
- for blk in self.blocks:
272
- x = blk(x)
273
- x = self.norm(x)
274
- return x
275
-
276
- def get_intermediate_feat(self, x, n=1, norm=True):
277
- x = self.prepare_tokens(x)
278
- # we return the output tokens from the `n` last blocks
279
- feat = []
280
- attns = []
281
- qkvs = []
282
- for i, blk in enumerate(self.blocks):
283
- x, attn, qkv = blk(x, return_qkv=True)
284
- if len(self.blocks) - i <= n:
285
- if norm:
286
- feat.append(self.norm(x))
287
- else:
288
- feat.append(x)
289
- qkvs.append(qkv)
290
- attns.append(attn)
291
- return feat, attns, qkvs
292
-
293
- def get_last_selfattention(self, x):
294
- x = self.prepare_tokens(x)
295
- for i, blk in enumerate(self.blocks):
296
- if i < len(self.blocks) - 1:
297
- x = blk(x)
298
- else:
299
- # return attention of the last block
300
- return blk(x, return_attention=True)
301
-
302
- def get_intermediate_layers(self, x, n=1):
303
- x = self.prepare_tokens(x)
304
- # we return the output tokens from the `n` last blocks
305
- output = []
306
- for i, blk in enumerate(self.blocks):
307
- x = blk(x)
308
- if len(self.blocks) - i <= n:
309
- output.append(self.norm(x))
310
- return output
311
-
312
-
313
- def vit_tiny(patch_size=16, **kwargs):
314
- model = VisionTransformer(
315
- patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
316
- qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=eps), **kwargs)
317
- return model
318
-
319
-
320
- def vit_small(patch_size=16, **kwargs):
321
- model = VisionTransformer(
322
- patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
323
- qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=eps), **kwargs)
324
- return model
325
-
326
-
327
- def vit_base(patch_size=16, **kwargs):
328
- model = VisionTransformer(
329
- patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
330
- qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=eps), **kwargs)
331
- return model
332
-
333
-
334
- class DINOHead(nn.Module):
335
- def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
336
- bottleneck_dim=256):
337
- super().__init__()
338
- nlayers = max(nlayers, 1)
339
- if nlayers == 1:
340
- self.mlp = nn.Linear(in_dim, bottleneck_dim)
341
- else:
342
- layers = [nn.Linear(in_dim, hidden_dim)]
343
- if use_bn:
344
- layers.append(nn.BatchNorm1d(hidden_dim))
345
- layers.append(nn.GELU())
346
- for _ in range(nlayers - 2):
347
- layers.append(nn.Linear(hidden_dim, hidden_dim))
348
- if use_bn:
349
- layers.append(nn.BatchNorm1d(hidden_dim))
350
- layers.append(nn.GELU())
351
- layers.append(nn.Linear(hidden_dim, bottleneck_dim))
352
- self.mlp = nn.Sequential(*layers)
353
- self.apply(self._init_weights)
354
- self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
355
- self.last_layer.weight_g.data.fill_(1)
356
- if norm_last_layer:
357
- self.last_layer.weight_g.requires_grad = False
358
-
359
- def _init_weights(self, m):
360
- if isinstance(m, nn.Linear):
361
- trunc_normal_(m.weight, std=.02)
362
- if isinstance(m, nn.Linear) and m.bias is not None:
363
- nn.init.constant_(m.bias, 0)
364
-
365
- def forward(self, x):
366
- x = self.mlp(x)
367
- x = nn.functional.normalize(x, dim=-1, p=2)
368
- x = self.last_layer(x)
369
- return x
370
-
371
-
372
-
373
- class DINOFeaturizer(nn.Module):
374
-
375
- def __init__(self, arch, patch_size, feat_type):
376
- super().__init__()
377
- self.arch = arch
378
- self.patch_size = patch_size
379
- self.feat_type = feat_type
380
-
381
- self.config = {
382
- "arch": arch,
383
- "patch_size": patch_size,
384
- "feat_type": feat_type
385
- }
386
-
387
- self.model = vit_small(
388
- patch_size=patch_size,
389
- num_classes=0)
390
-
391
- if "3d-dino" in arch:
392
- state_dict = torch.load("../models/3d-dino-co3d.pth")["teacher"]
393
- state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
394
- state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
395
- elif "iarpa-dino" in arch:
396
- state_dict = torch.load("../models/dino_iarpa.pth")["teacher"]
397
- state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
398
- state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
399
- elif "chk-dino" in arch:
400
- state_dict = torch.load("../models/dino_deitsmall16_pretrain_full_checkpoint.pth")["teacher"]
401
- state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
402
- state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
403
- elif "ft_dino" in arch:
404
- arch = "_".join(arch.split("_")[:-1])
405
- state_dict = torch.load("../models/{}.pth".format(arch))["teacher"]
406
- state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
407
- state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
408
- elif "dino" in arch:
409
- state_dict = torch.hub.load('facebookresearch/dino:main', self.arch).state_dict()
410
- else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
411
- temp_model = timm.create_model(self.arch, pretrained=True)
412
- state_dict = temp_model.state_dict()
413
- del state_dict['head.weight']
414
- del state_dict['head.bias']
415
-
416
- self.model.load_state_dict(state_dict, strict=True)
417
-
418
- if arch == "vit_small":
419
- self.n_feats = 384
420
- else:
421
- self.n_feats = 768
422
-
423
- def get_cls_token(self, img):
424
- return self.model.forward(img)
425
-
426
- def forward(self, img, n=1, include_cls=False):
427
- assert (img.shape[2] % self.patch_size == 0)
428
- assert (img.shape[3] % self.patch_size == 0)
429
-
430
- feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
431
- feat, attn, qkv = feat[0], attn[0], qkv[0]
432
-
433
- feat_h = img.shape[2] // self.patch_size
434
- feat_w = img.shape[3] // self.patch_size
435
-
436
- if self.feat_type == "token":
437
- image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
438
- cls_feat = feat[:, 0, :]
439
- elif self.feat_type == "key":
440
- x = qkv[1, :, :, 1:, :] # remove cls token
441
- desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1)
442
- image_feat = desc.reshape(desc.shape[0], feat_h, feat_w, desc.shape[2]) \
443
- .permute(0, 3, 1, 2)
444
- cls_feat = None
445
- else:
446
- raise ValueError("Unknown feat type:{}".format(self.feat_type))
447
-
448
- if include_cls:
449
- return image_feat, cls_feat
450
-
451
- return image_feat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/DINOv2.py DELETED
@@ -1,49 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- class DINOv2Featurizer(nn.Module):
6
-
7
- def __init__(self):
8
- super().__init__()
9
- self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').cuda()
10
- # self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
11
- self.model.eval()
12
- self.config = {}
13
-
14
- def get_cls_token(self, img):
15
- pass
16
-
17
- def forward(self, img, include_cls):
18
- feature_dict = self.model.forward_features(img)
19
- _, _, h, w = img.shape
20
- new_h, new_w = h // 14, w // 14
21
- b, _, c = feature_dict["x_norm_patchtokens"].shape
22
- spatial_tokens = feature_dict["x_norm_patchtokens"].permute(0, 2, 1).reshape(b, c, new_h, new_w)
23
-
24
- if include_cls:
25
- return spatial_tokens, feature_dict["x_norm_clstoken"]
26
- else:
27
- return spatial_tokens
28
-
29
-
30
- if __name__ == "__main__":
31
- import torchvision.transforms as T
32
- from PIL import Image
33
- from shared import norm, crop_to_divisor
34
-
35
- device = "cuda" if torch.cuda.is_available() else "cpu"
36
-
37
- image = Image.open("../../samples/dog_man_1_crop.jpg")
38
- load_size = 224 # * 3
39
- transform = T.Compose([
40
- T.Resize(load_size, Image.BILINEAR),
41
- T.CenterCrop(load_size),
42
- T.ToTensor(),
43
- norm])
44
-
45
- model = DINOv2Featurizer().cuda()
46
-
47
- results = model(transform(image).cuda().unsqueeze(0), include_cls=False)
48
-
49
- print(results.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/Hubert.py DELETED
@@ -1,70 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from transformers import Wav2Vec2Processor, HubertModel, HubertConfig
4
- from transformers.pytorch_utils import Conv1D
5
-
6
- class HubertAudioTransform():
7
-
8
- def __init__(self):
9
- self.processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
10
-
11
- def __call__(self, audio):
12
- return self.processor(audio, return_tensors="pt", sampling_rate=16000).input_values.squeeze(0)
13
-
14
-
15
- def copy_conv(l):
16
- new_l = Conv1D()
17
-
18
-
19
- class Hubert(nn.Module):
20
- def __init__(self):
21
- super().__init__()
22
- model1 = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
23
- config = model1.config
24
- del model1
25
- config.layer_norm_eps = 1e-4
26
- self.model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft", config=config)
27
- self.config = dict()
28
-
29
-
30
- def forward(self, audio, include_cls):
31
- outputs = self.model(audio)
32
- # outputs = deepspeed.checkpointing.checkpoint(self.model, audio)
33
-
34
- patch_tokens = outputs.last_hidden_state.permute(0, 2, 1).unsqueeze(2)
35
-
36
- # return patch_tokens
37
- if include_cls:
38
- return patch_tokens, None
39
- else:
40
- return patch_tokens
41
-
42
- def get_last_params(self):
43
- return self.model.encoder.layers[-1].parameters()
44
-
45
-
46
- if __name__ == "__main__":
47
- import librosa
48
- from shared import pca, remove_axes
49
- import matplotlib.pyplot as plt
50
- from pytorch_lightning import seed_everything
51
-
52
- audio, _ = librosa.load("../../samples/example.wav", sr=16000)
53
- audio = torch.from_numpy(audio).unsqueeze(0).to("cuda")
54
-
55
- model = Hubert().to("cuda")
56
- embeddings = model.forward(audio, include_cls=False)
57
-
58
- print(embeddings.shape)
59
- seed_everything(0)
60
-
61
- with torch.no_grad():
62
- [pca_feats], _ = pca([embeddings])
63
- pca_feats = torch.broadcast_to(
64
- pca_feats, (pca_feats.shape[0], pca_feats.shape[1], 25, pca_feats.shape[3]))
65
- fig, axes = plt.subplots(2, 1, figsize=(10, 7))
66
- axes[1].imshow(pca_feats.cpu().squeeze(0).permute(1, 2, 0))
67
- remove_axes(axes)
68
- plt.tight_layout()
69
- plt.show()
70
- print("here")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/ImageBind.py DELETED
@@ -1,2033 +0,0 @@
1
- import gzip
2
- import html
3
- import io
4
- import logging
5
- import math
6
- import os
7
- from functools import lru_cache
8
- from functools import partial
9
- from types import SimpleNamespace
10
- from typing import Callable, List
11
- from typing import Optional
12
-
13
- import einops
14
- import ftfy
15
- import numpy as np
16
- import regex as re
17
- import torch
18
- import torch.nn as nn
19
- import torch.nn.functional as F
20
- import torch.utils.checkpoint as checkpoint
21
- import torchaudio
22
- import torchvision.transforms as T
23
- from PIL import Image
24
- from timm.models.layers import DropPath, trunc_normal_
25
- from torchvision import transforms
26
- import matplotlib.pyplot as plt
27
- from iopath.common.file_io import g_pathmgr
28
-
29
-
30
- class Attention(nn.Module):
31
- def __init__(
32
- self,
33
- dim,
34
- num_heads=8,
35
- qkv_bias=False,
36
- qk_scale=None,
37
- attn_drop=0.0,
38
- proj_drop=0.0,
39
- ):
40
- super().__init__()
41
- self.num_heads = num_heads
42
- head_dim = dim // num_heads
43
- # NOTE scale factor was wrong in my original version,
44
- # can set manually to be compat with prev weights
45
- self.scale = qk_scale or head_dim ** -0.5
46
-
47
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
48
- self.attn_drop = nn.Dropout(attn_drop)
49
- self.proj = nn.Linear(dim, dim)
50
- self.proj_drop = nn.Dropout(proj_drop)
51
-
52
- def forward(self, x):
53
- B, N, C = x.shape
54
- qkv = (
55
- self.qkv(x)
56
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
57
- .permute(2, 0, 3, 1, 4)
58
- )
59
- q, k, v = (
60
- qkv[0],
61
- qkv[1],
62
- qkv[2],
63
- ) # make torchscript happy (cannot use tensor as tuple)
64
-
65
- attn = (q @ k.transpose(-2, -1)) * self.scale
66
- attn = attn.softmax(dim=-1)
67
- attn = self.attn_drop(attn)
68
-
69
- x = (attn @ v).transpose(1, 2).reshape(B, N, C)
70
- x = self.proj(x)
71
- x = self.proj_drop(x)
72
- return x
73
-
74
-
75
- class Mlp(nn.Module):
76
- def __init__(
77
- self,
78
- in_features,
79
- hidden_features=None,
80
- out_features=None,
81
- act_layer=nn.GELU,
82
- drop=0.0,
83
- ):
84
- super().__init__()
85
- out_features = out_features or in_features
86
- hidden_features = hidden_features or in_features
87
- self.fc1 = nn.Linear(in_features, hidden_features)
88
- self.act = act_layer()
89
- self.fc2 = nn.Linear(hidden_features, out_features)
90
- self.drop = nn.Dropout(drop)
91
-
92
- def forward(self, x):
93
- x = self.fc1(x)
94
- x = self.act(x)
95
- x = self.drop(x)
96
- x = self.fc2(x)
97
- x = self.drop(x)
98
- return x
99
-
100
-
101
- class MultiheadAttention(nn.MultiheadAttention):
102
- def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
103
- return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
104
-
105
-
106
- class ViTAttention(Attention):
107
- def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
108
- assert attn_mask is None
109
- return super().forward(x)
110
-
111
-
112
- class BlockWithMasking(nn.Module):
113
- def __init__(
114
- self,
115
- dim: int,
116
- attn_target: Callable,
117
- mlp_ratio: int = 4,
118
- act_layer: Callable = nn.GELU,
119
- norm_layer: Callable = nn.LayerNorm,
120
- ffn_dropout_rate: float = 0.0,
121
- drop_path: float = 0.0,
122
- layer_scale_type: str = None,
123
- layer_scale_init_value: float = 1e-4,
124
- ):
125
- super().__init__()
126
-
127
- assert not isinstance(
128
- attn_target, nn.Module
129
- ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
130
- self.attn = attn_target()
131
- if drop_path > 0.0:
132
- self.drop_path = DropPath(drop_path)
133
- else:
134
- self.drop_path = nn.Identity()
135
- self.norm_1 = norm_layer(dim)
136
- mlp_hidden_dim = int(mlp_ratio * dim)
137
- self.mlp = Mlp(
138
- in_features=dim,
139
- hidden_features=mlp_hidden_dim,
140
- act_layer=act_layer,
141
- drop=ffn_dropout_rate,
142
- )
143
- self.norm_2 = norm_layer(dim)
144
- self.layer_scale_type = layer_scale_type
145
- if self.layer_scale_type is not None:
146
- assert self.layer_scale_type in [
147
- "per_channel",
148
- "scalar",
149
- ], f"Found Layer scale type {self.layer_scale_type}"
150
- if self.layer_scale_type == "per_channel":
151
- # one gamma value per channel
152
- gamma_shape = [1, 1, dim]
153
- elif self.layer_scale_type == "scalar":
154
- # single gamma value for all channels
155
- gamma_shape = [1, 1, 1]
156
- # two gammas: for each part of the fwd in the encoder
157
- self.layer_scale_gamma1 = nn.Parameter(
158
- torch.ones(size=gamma_shape) * layer_scale_init_value,
159
- requires_grad=True,
160
- )
161
- self.layer_scale_gamma2 = nn.Parameter(
162
- torch.ones(size=gamma_shape) * layer_scale_init_value,
163
- requires_grad=True,
164
- )
165
-
166
- def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
167
- if self.layer_scale_type is None:
168
- x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
169
- x = x + self.drop_path(self.mlp(self.norm_2(x)))
170
- else:
171
- x = (
172
- x
173
- + self.drop_path(self.attn(self.norm_1(x), attn_mask))
174
- * self.layer_scale_gamma1
175
- )
176
- x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
177
- return x
178
-
179
-
180
- _LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
181
-
182
-
183
- class SimpleTransformer(nn.Module):
184
- def __init__(
185
- self,
186
- attn_target: Callable,
187
- embed_dim: int,
188
- num_blocks: int,
189
- block: Callable = BlockWithMasking,
190
- pre_transformer_layer: Callable = None,
191
- post_transformer_layer: Callable = None,
192
- drop_path_rate: float = 0.0,
193
- drop_path_type: str = "progressive",
194
- norm_layer: Callable = _LAYER_NORM,
195
- mlp_ratio: int = 4,
196
- ffn_dropout_rate: float = 0.0,
197
- layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar"
198
- layer_scale_init_value: float = 1e-4, # from cait; float
199
- weight_init_style: str = "jax", # possible values jax or pytorch
200
- ):
201
- """
202
- Simple Transformer with the following features
203
- 1. Supports masked attention
204
- 2. Supports DropPath
205
- 3. Supports LayerScale
206
- 4. Supports Dropout in Attention and FFN
207
- 5. Makes few assumptions about the input except that it is a Tensor
208
- """
209
- super().__init__()
210
- self.pre_transformer_layer = pre_transformer_layer
211
- if drop_path_type == "progressive":
212
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
213
- elif drop_path_type == "uniform":
214
- dpr = [drop_path_rate for i in range(num_blocks)]
215
- else:
216
- raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
217
-
218
- self.blocks = nn.Sequential(
219
- *[
220
- block(
221
- dim=embed_dim,
222
- attn_target=attn_target,
223
- mlp_ratio=mlp_ratio,
224
- ffn_dropout_rate=ffn_dropout_rate,
225
- drop_path=dpr[i],
226
- norm_layer=norm_layer,
227
- layer_scale_type=layer_scale_type,
228
- layer_scale_init_value=layer_scale_init_value,
229
- )
230
- for i in range(num_blocks)
231
- ]
232
- )
233
- self.post_transformer_layer = post_transformer_layer
234
- self.weight_init_style = weight_init_style
235
- self.apply(self._init_weights)
236
-
237
- def _init_weights(self, m):
238
- if isinstance(m, nn.Linear):
239
- if self.weight_init_style == "jax":
240
- # Based on MAE and official Jax ViT implementation
241
- torch.nn.init.xavier_uniform_(m.weight)
242
- elif self.weight_init_style == "pytorch":
243
- # PyTorch ViT uses trunc_normal_
244
- trunc_normal_(m.weight, std=0.02)
245
-
246
- if m.bias is not None:
247
- nn.init.constant_(m.bias, 0)
248
- elif isinstance(m, (nn.LayerNorm)):
249
- nn.init.constant_(m.bias, 0)
250
- nn.init.constant_(m.weight, 1.0)
251
-
252
- def forward(
253
- self,
254
- tokens: torch.Tensor,
255
- attn_mask: torch.Tensor = None,
256
- use_checkpoint: bool = False,
257
- checkpoint_every_n: int = 1,
258
- checkpoint_blk_ids: List[int] = None,
259
- ):
260
- """
261
- Inputs
262
- - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
263
- - attn: mask of shape L x L
264
-
265
- Output
266
- - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
267
- """
268
- if self.pre_transformer_layer:
269
- tokens = self.pre_transformer_layer(tokens)
270
- if use_checkpoint and checkpoint_blk_ids is None:
271
- checkpoint_blk_ids = [
272
- blk_id
273
- for blk_id in range(len(self.blocks))
274
- if blk_id % checkpoint_every_n == 0
275
- ]
276
- if checkpoint_blk_ids:
277
- checkpoint_blk_ids = set(checkpoint_blk_ids)
278
- for blk_id, blk in enumerate(self.blocks):
279
- if use_checkpoint and blk_id in checkpoint_blk_ids:
280
- tokens = checkpoint.checkpoint(
281
- blk, tokens, attn_mask, use_reentrant=False
282
- )
283
- else:
284
- tokens = blk(tokens, attn_mask=attn_mask)
285
- if self.post_transformer_layer:
286
- tokens = self.post_transformer_layer(tokens)
287
- return tokens
288
-
289
-
290
- def get_sinusoid_encoding_table(n_position, d_hid):
291
- """Sinusoid position encoding table"""
292
-
293
- # TODO: make it with torch instead of numpy
294
- def get_position_angle_vec(position):
295
- return [
296
- position / np.power(10000, 2 * (hid_j // 2) / d_hid)
297
- for hid_j in range(d_hid)
298
- ]
299
-
300
- sinusoid_table = np.array(
301
- [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
302
- )
303
- sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
304
- sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
305
-
306
- return torch.FloatTensor(sinusoid_table).unsqueeze(0)
307
-
308
-
309
- def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
310
- N = pos_embed.shape[1]
311
- if N == target_spatial_size:
312
- return pos_embed
313
- dim = pos_embed.shape[-1]
314
- # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
315
- pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
316
- pos_embed = nn.functional.interpolate(
317
- pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
318
- 0, 3, 1, 2
319
- ),
320
- scale_factor=math.sqrt(target_spatial_size / N),
321
- mode="bicubic",
322
- )
323
- if updated:
324
- pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
325
- pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
326
- return pos_embed
327
-
328
-
329
- def interpolate_pos_encoding(
330
- npatch_per_img,
331
- pos_embed,
332
- patches_layout,
333
- input_shape=None,
334
- first_patch_idx=1,
335
- ):
336
- assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
337
- N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
338
- if npatch_per_img == N:
339
- return pos_embed
340
-
341
- # assert (
342
- # patches_layout[-1] == patches_layout[-2]
343
- # ), "Interpolation of pos embed not supported for non-square layouts"
344
-
345
- class_emb = pos_embed[:, :first_patch_idx]
346
- pos_embed = pos_embed[:, first_patch_idx:]
347
-
348
- if input_shape is None or patches_layout[0] == 1:
349
- # simple 2D pos embedding, no temporal component
350
- pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
351
- elif patches_layout[0] > 1:
352
- # pos embed has a temporal component
353
- assert len(input_shape) == 4, "temporal interpolation not supported"
354
- # we only support 2D interpolation in this case
355
- num_frames = patches_layout[0]
356
- num_spatial_tokens = patches_layout[1] * patches_layout[2]
357
- pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
358
- # interpolate embedding for zeroth frame
359
- pos_embed = interpolate_pos_encoding_2d(
360
- npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
361
- )
362
- else:
363
- raise ValueError("This type of interpolation isn't implemented")
364
-
365
- return torch.cat((class_emb, pos_embed), dim=1)
366
-
367
-
368
- def _get_pos_embedding(
369
- npatch_per_img,
370
- pos_embed,
371
- patches_layout,
372
- input_shape,
373
- first_patch_idx=1,
374
- ):
375
- pos_embed = interpolate_pos_encoding(
376
- npatch_per_img,
377
- pos_embed,
378
- patches_layout,
379
- input_shape=input_shape,
380
- first_patch_idx=first_patch_idx,
381
- )
382
- return pos_embed
383
-
384
-
385
- class VerboseNNModule(nn.Module):
386
- """
387
- Wrapper around nn.Module that prints registered buffers and parameter names.
388
- """
389
-
390
- @staticmethod
391
- def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
392
- st = (
393
- "("
394
- + name
395
- + "): "
396
- + "tensor("
397
- + str(tuple(tensor[1].shape))
398
- + ", requires_grad="
399
- + str(tensor[1].requires_grad)
400
- + ")\n"
401
- )
402
- return st
403
-
404
- def extra_repr(self) -> str:
405
- named_modules = set()
406
- for p in self.named_modules():
407
- named_modules.update([p[0]])
408
- named_modules = list(named_modules)
409
-
410
- string_repr = ""
411
- for p in self.named_parameters():
412
- name = p[0].split(".")[0]
413
- if name not in named_modules:
414
- string_repr += self.get_readable_tensor_repr(name, p)
415
-
416
- for p in self.named_buffers():
417
- name = p[0].split(".")[0]
418
- string_repr += self.get_readable_tensor_repr(name, p)
419
-
420
- return string_repr
421
-
422
-
423
- class PatchEmbedGeneric(nn.Module):
424
- """
425
- PatchEmbed from Hydra
426
- """
427
-
428
- def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
429
- super().__init__()
430
-
431
- if len(proj_stem) > 1:
432
- self.proj = nn.Sequential(*proj_stem)
433
- else:
434
- # Special case to be able to load pre-trained models that were
435
- # trained with a standard stem
436
- self.proj = proj_stem[0]
437
- self.norm_layer = norm_layer
438
-
439
- def get_patch_layout(self, img_size):
440
- with torch.no_grad():
441
- dummy_img = torch.zeros(
442
- [
443
- 1,
444
- ]
445
- + img_size
446
- )
447
- dummy_out = self.proj(dummy_img)
448
- embed_dim = dummy_out.shape[1]
449
- patches_layout = tuple(dummy_out.shape[2:])
450
- num_patches = np.prod(patches_layout)
451
- return patches_layout, num_patches, embed_dim
452
-
453
- def forward(self, x):
454
- x = self.proj(x)
455
- # B C (T) H W -> B (T)HW C
456
- x = x.flatten(2).transpose(1, 2)
457
- if self.norm_layer is not None:
458
- x = self.norm_layer(x)
459
- return x
460
-
461
-
462
- class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
463
- def __init__(
464
- self,
465
- patches_layout: List,
466
- num_patches: int,
467
- num_cls_tokens: int,
468
- embed_dim: int,
469
- learnable: bool,
470
- ) -> None:
471
- super().__init__()
472
- self.num_cls_tokens = num_cls_tokens
473
- self.patches_layout = patches_layout
474
- self.num_patches = num_patches
475
- self.num_tokens = num_cls_tokens + num_patches
476
- self.learnable = learnable
477
- if self.learnable:
478
- self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
479
- trunc_normal_(self.pos_embed, std=0.02)
480
- else:
481
- self.register_buffer(
482
- "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
483
- )
484
-
485
- def get_pos_embedding(self, vision_input, all_vision_tokens):
486
- input_shape = vision_input.shape
487
- pos_embed = _get_pos_embedding(
488
- all_vision_tokens.size(1) - self.num_cls_tokens,
489
- pos_embed=self.pos_embed,
490
- patches_layout=self.patches_layout,
491
- input_shape=input_shape,
492
- first_patch_idx=self.num_cls_tokens,
493
- )
494
- return pos_embed
495
-
496
-
497
- class RGBDTPreprocessor(VerboseNNModule):
498
- def __init__(
499
- self,
500
- rgbt_stem: PatchEmbedGeneric,
501
- depth_stem: PatchEmbedGeneric,
502
- img_size: List = (3, 224, 224),
503
- num_cls_tokens: int = 1,
504
- pos_embed_fn: Callable = None,
505
- use_type_embed: bool = False,
506
- init_param_style: str = "openclip",
507
- ) -> None:
508
- super().__init__()
509
- stem = rgbt_stem if rgbt_stem is not None else depth_stem
510
- (
511
- self.patches_layout,
512
- self.num_patches,
513
- self.embed_dim,
514
- ) = stem.get_patch_layout(img_size)
515
- self.rgbt_stem = rgbt_stem
516
- self.depth_stem = depth_stem
517
- self.use_pos_embed = pos_embed_fn is not None
518
- self.use_type_embed = use_type_embed
519
- self.num_cls_tokens = num_cls_tokens
520
-
521
- if self.use_pos_embed:
522
- self.pos_embedding_helper = pos_embed_fn(
523
- patches_layout=self.patches_layout,
524
- num_cls_tokens=num_cls_tokens,
525
- num_patches=self.num_patches,
526
- embed_dim=self.embed_dim,
527
- )
528
- if self.num_cls_tokens > 0:
529
- self.cls_token = nn.Parameter(
530
- torch.zeros(1, self.num_cls_tokens, self.embed_dim)
531
- )
532
- if self.use_type_embed:
533
- self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
534
-
535
- self.init_parameters(init_param_style)
536
-
537
- @torch.no_grad()
538
- def init_parameters(self, init_param_style):
539
- if init_param_style == "openclip":
540
- # OpenCLIP style initialization
541
- scale = self.embed_dim ** -0.5
542
- if self.use_pos_embed:
543
- nn.init.normal_(self.pos_embedding_helper.pos_embed)
544
- self.pos_embedding_helper.pos_embed *= scale
545
-
546
- if self.num_cls_tokens > 0:
547
- nn.init.normal_(self.cls_token)
548
- self.cls_token *= scale
549
- elif init_param_style == "vit":
550
- self.cls_token.data.fill_(0)
551
- else:
552
- raise ValueError(f"Unknown init {init_param_style}")
553
-
554
- if self.use_type_embed:
555
- nn.init.normal_(self.type_embed)
556
-
557
- def get_pos_emb_2(self, input, stem):
558
- patches = stem.proj(input)
559
- target_size = patches.shape[-2:]
560
- original_size = list(self.pos_embedding_helper.patches_layout)[-2:]
561
-
562
- orig_ce = self.pos_embedding_helper.pos_embed[:, 0, :]
563
- orig_pe = ((self.pos_embedding_helper.pos_embed[:, 1:, :]
564
- .reshape(1, *original_size, self.embed_dim))
565
- .permute(0, 3, 1, 2))
566
-
567
- new_pe = F.interpolate(orig_pe, size=target_size, mode="bicubic")
568
-
569
- new_full_pe = torch.cat([orig_ce.unsqueeze(1), new_pe.permute(0, 2, 3, 1).reshape(1, -1, self.embed_dim)],
570
- dim=1)
571
-
572
- return new_full_pe
573
-
574
- def tokenize_input_and_cls_pos(self, input, stem, mask):
575
- # tokens is of shape B x L x D
576
- tokens = stem(input)
577
- assert tokens.ndim == 3
578
- assert tokens.shape[2] == self.embed_dim
579
- B = tokens.shape[0]
580
- if self.num_cls_tokens > 0:
581
- class_tokens = self.cls_token.expand(
582
- B, -1, -1
583
- ) # stole class_tokens impl from Phil Wang, thanks
584
- tokens = torch.cat((class_tokens, tokens), dim=1)
585
- if self.use_pos_embed:
586
- pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
587
- # pos_embed = self.get_pos_emb_2(input, stem)
588
- tokens = tokens + pos_embed
589
- if self.use_type_embed:
590
- tokens = tokens + self.type_embed.expand(B, -1, -1)
591
- return tokens
592
-
593
- def forward(self, vision=None, depth=None, patch_mask=None):
594
- if patch_mask is not None:
595
- raise NotImplementedError()
596
-
597
- if vision is not None:
598
- vision_tokens = self.tokenize_input_and_cls_pos(
599
- vision, self.rgbt_stem, patch_mask
600
- )
601
-
602
- if depth is not None:
603
- depth_tokens = self.tokenize_input_and_cls_pos(
604
- depth, self.depth_stem, patch_mask
605
- )
606
-
607
- # aggregate tokens
608
- if vision is not None and depth is not None:
609
- final_tokens = vision_tokens + depth_tokens
610
- else:
611
- final_tokens = vision_tokens if vision is not None else depth_tokens
612
- return_dict = {
613
- "trunk": {
614
- "tokens": final_tokens,
615
- },
616
- "head": {},
617
- }
618
- return return_dict
619
-
620
-
621
- class AudioPreprocessor(RGBDTPreprocessor):
622
- def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
623
- super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
624
-
625
- def forward(self, audio=None):
626
- return super().forward(vision=audio)
627
-
628
-
629
- class ThermalPreprocessor(RGBDTPreprocessor):
630
- def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
631
- super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
632
-
633
- def forward(self, thermal=None):
634
- return super().forward(vision=thermal)
635
-
636
-
637
- def build_causal_attention_mask(context_length):
638
- # lazily create causal attention mask, with full attention between the vision tokens
639
- # pytorch uses additive attention mask; fill with -inf
640
- mask = torch.empty(context_length, context_length, requires_grad=False)
641
- mask.fill_(float("-inf"))
642
- mask.triu_(1) # zero out the lower diagonal
643
- return mask
644
-
645
-
646
- class TextPreprocessor(VerboseNNModule):
647
- def __init__(
648
- self,
649
- vocab_size: int,
650
- context_length: int,
651
- embed_dim: int,
652
- causal_masking: bool,
653
- supply_seq_len_to_head: bool = True,
654
- num_cls_tokens: int = 0,
655
- init_param_style: str = "openclip",
656
- ) -> None:
657
- super().__init__()
658
- self.vocab_size = vocab_size
659
- self.context_length = context_length
660
- self.token_embedding = nn.Embedding(vocab_size, embed_dim)
661
- self.pos_embed = nn.Parameter(
662
- torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
663
- )
664
- self.causal_masking = causal_masking
665
- if self.causal_masking:
666
- mask = build_causal_attention_mask(self.context_length)
667
- # register the mask as a buffer, so it can be moved to the right device
668
- self.register_buffer("mask", mask)
669
-
670
- self.supply_seq_len_to_head = supply_seq_len_to_head
671
- self.num_cls_tokens = num_cls_tokens
672
- self.embed_dim = embed_dim
673
- if num_cls_tokens > 0:
674
- assert self.causal_masking is False, "Masking + CLS token isn't implemented"
675
- self.cls_token = nn.Parameter(
676
- torch.zeros(1, self.num_cls_tokens, embed_dim)
677
- )
678
-
679
- self.init_parameters(init_param_style)
680
-
681
- @torch.no_grad()
682
- def init_parameters(self, init_param_style="openclip"):
683
- # OpenCLIP style initialization
684
- nn.init.normal_(self.token_embedding.weight, std=0.02)
685
- nn.init.normal_(self.pos_embed, std=0.01)
686
-
687
- if init_param_style == "openclip":
688
- # OpenCLIP style initialization
689
- scale = self.embed_dim ** -0.5
690
- if self.num_cls_tokens > 0:
691
- nn.init.normal_(self.cls_token)
692
- self.cls_token *= scale
693
- elif init_param_style == "vit":
694
- self.cls_token.data.fill_(0)
695
- else:
696
- raise ValueError(f"Unknown init {init_param_style}")
697
-
698
- def forward(self, text):
699
- # text tokens are of shape B x L x D
700
- text_tokens = self.token_embedding(text)
701
- # concat CLS tokens if any
702
- if self.num_cls_tokens > 0:
703
- B = text_tokens.shape[0]
704
- class_tokens = self.cls_token.expand(
705
- B, -1, -1
706
- ) # stole class_tokens impl from Phil Wang, thanks
707
- text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
708
- text_tokens = text_tokens + self.pos_embed
709
- return_dict = {
710
- "trunk": {
711
- "tokens": text_tokens,
712
- },
713
- "head": {},
714
- }
715
- # Compute sequence length after adding CLS tokens
716
- if self.supply_seq_len_to_head:
717
- text_lengths = text.argmax(dim=-1)
718
- return_dict["head"] = {
719
- "seq_len": text_lengths,
720
- }
721
- if self.causal_masking:
722
- return_dict["trunk"].update({"attn_mask": self.mask})
723
- return return_dict
724
-
725
-
726
- class Im2Video(nn.Module):
727
- """Convert an image into a trivial video."""
728
-
729
- def __init__(self, time_dim=2):
730
- super().__init__()
731
- self.time_dim = time_dim
732
-
733
- def forward(self, x):
734
- if x.ndim == 4:
735
- # B, C, H, W -> B, C, T, H, W
736
- return x.unsqueeze(self.time_dim)
737
- elif x.ndim == 5:
738
- return x
739
- else:
740
- raise ValueError(f"Dimension incorrect {x.shape}")
741
-
742
-
743
- class PadIm2Video(Im2Video):
744
- def __init__(self, ntimes, pad_type, time_dim=2):
745
- super().__init__(time_dim=time_dim)
746
- assert ntimes > 0
747
- assert pad_type in ["zero", "repeat"]
748
- self.ntimes = ntimes
749
- self.pad_type = pad_type
750
-
751
- def forward(self, x):
752
- x = super().forward(x)
753
- if x.shape[self.time_dim] == 1:
754
- if self.pad_type == "repeat":
755
- new_shape = [1] * len(x.shape)
756
- new_shape[self.time_dim] = self.ntimes
757
- x = x.repeat(new_shape)
758
- elif self.pad_type == "zero":
759
- padarg = [0, 0] * len(x.shape)
760
- padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
761
- x = nn.functional.pad(x, padarg)
762
- return x
763
-
764
-
765
- # Modified from github.com/openai/CLIP
766
- @lru_cache()
767
- def bytes_to_unicode():
768
- """
769
- Returns list of utf-8 byte and a corresponding list of unicode strings.
770
- The reversible bpe codes work on unicode strings.
771
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
772
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
773
- This is a signficant percentage of your normal, say, 32K bpe vocab.
774
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
775
- And avoids mapping to whitespace/control characters the bpe code barfs on.
776
- """
777
- bs = (
778
- list(range(ord("!"), ord("~") + 1))
779
- + list(range(ord("¡"), ord("¬") + 1))
780
- + list(range(ord("®"), ord("ÿ") + 1))
781
- )
782
- cs = bs[:]
783
- n = 0
784
- for b in range(2 ** 8):
785
- if b not in bs:
786
- bs.append(b)
787
- cs.append(2 ** 8 + n)
788
- n += 1
789
- cs = [chr(n) for n in cs]
790
- return dict(zip(bs, cs))
791
-
792
-
793
- def get_pairs(word):
794
- """Return set of symbol pairs in a word.
795
- Word is represented as tuple of symbols (symbols being variable-length strings).
796
- """
797
- pairs = set()
798
- prev_char = word[0]
799
- for char in word[1:]:
800
- pairs.add((prev_char, char))
801
- prev_char = char
802
- return pairs
803
-
804
-
805
- def basic_clean(text):
806
- text = ftfy.fix_text(text)
807
- text = html.unescape(html.unescape(text))
808
- return text.strip()
809
-
810
-
811
- def whitespace_clean(text):
812
- text = re.sub(r"\s+", " ", text)
813
- text = text.strip()
814
- return text
815
-
816
-
817
- class SimpleTokenizer(object):
818
- def __init__(self, bpe_path: str, context_length=77):
819
- self.byte_encoder = bytes_to_unicode()
820
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
821
-
822
- with g_pathmgr.open(bpe_path, "rb") as fh:
823
- bpe_bytes = io.BytesIO(fh.read())
824
- merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
825
- merges = merges[1: 49152 - 256 - 2 + 1]
826
- merges = [tuple(merge.split()) for merge in merges]
827
- vocab = list(bytes_to_unicode().values())
828
- vocab = vocab + [v + "</w>" for v in vocab]
829
- for merge in merges:
830
- vocab.append("".join(merge))
831
- vocab.extend(["<|startoftext|>", "<|endoftext|>"])
832
- self.encoder = dict(zip(vocab, range(len(vocab))))
833
- self.decoder = {v: k for k, v in self.encoder.items()}
834
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
835
- self.cache = {
836
- "<|startoftext|>": "<|startoftext|>",
837
- "<|endoftext|>": "<|endoftext|>",
838
- }
839
- self.pat = re.compile(
840
- r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
841
- re.IGNORECASE,
842
- )
843
- self.context_length = context_length
844
-
845
- def bpe(self, token):
846
- if token in self.cache:
847
- return self.cache[token]
848
- word = tuple(token[:-1]) + (token[-1] + "</w>",)
849
- pairs = get_pairs(word)
850
-
851
- if not pairs:
852
- return token + "</w>"
853
-
854
- while True:
855
- bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
856
- if bigram not in self.bpe_ranks:
857
- break
858
- first, second = bigram
859
- new_word = []
860
- i = 0
861
- while i < len(word):
862
- try:
863
- j = word.index(first, i)
864
- new_word.extend(word[i:j])
865
- i = j
866
- except:
867
- new_word.extend(word[i:])
868
- break
869
-
870
- if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
871
- new_word.append(first + second)
872
- i += 2
873
- else:
874
- new_word.append(word[i])
875
- i += 1
876
- new_word = tuple(new_word)
877
- word = new_word
878
- if len(word) == 1:
879
- break
880
- else:
881
- pairs = get_pairs(word)
882
- word = " ".join(word)
883
- self.cache[token] = word
884
- return word
885
-
886
- def encode(self, text):
887
- bpe_tokens = []
888
- text = whitespace_clean(basic_clean(text)).lower()
889
- for token in re.findall(self.pat, text):
890
- token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
891
- bpe_tokens.extend(
892
- self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
893
- )
894
- return bpe_tokens
895
-
896
- def decode(self, tokens):
897
- text = "".join([self.decoder[token] for token in tokens])
898
- text = (
899
- bytearray([self.byte_decoder[c] for c in text])
900
- .decode("utf-8", errors="replace")
901
- .replace("</w>", " ")
902
- )
903
- return text
904
-
905
- def __call__(self, texts, context_length=None):
906
- if not context_length:
907
- context_length = self.context_length
908
-
909
- if isinstance(texts, str):
910
- texts = [texts]
911
-
912
- sot_token = self.encoder["<|startoftext|>"]
913
- eot_token = self.encoder["<|endoftext|>"]
914
- all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
915
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
916
-
917
- for i, tokens in enumerate(all_tokens):
918
- tokens = tokens[:context_length]
919
- result[i, : len(tokens)] = torch.tensor(tokens)
920
-
921
- if len(result) == 1:
922
- return result[0]
923
- return result
924
-
925
-
926
- class Normalize(nn.Module):
927
- def __init__(self, dim: int) -> None:
928
- super().__init__()
929
- self.dim = dim
930
-
931
- def forward(self, x):
932
- return torch.nn.functional.normalize(x, dim=self.dim, p=2)
933
-
934
-
935
- class LearnableLogitScaling(nn.Module):
936
- def __init__(
937
- self,
938
- logit_scale_init: float = 1 / 0.07,
939
- learnable: bool = True,
940
- max_logit_scale: float = 100,
941
- ) -> None:
942
- super().__init__()
943
- self.max_logit_scale = max_logit_scale
944
- self.logit_scale_init = logit_scale_init
945
- self.learnable = learnable
946
- log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
947
- if learnable:
948
- self.log_logit_scale = nn.Parameter(log_logit_scale)
949
- else:
950
- self.register_buffer("log_logit_scale", log_logit_scale)
951
-
952
- def forward(self, x):
953
- return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
954
-
955
- def extra_repr(self):
956
- st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
957
- return st
958
-
959
-
960
- class EinOpsRearrange(nn.Module):
961
- def __init__(self, rearrange_expr: str, **kwargs) -> None:
962
- super().__init__()
963
- self.rearrange_expr = rearrange_expr
964
- self.kwargs = kwargs
965
-
966
- def forward(self, x):
967
- assert isinstance(x, torch.Tensor)
968
- return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
969
-
970
-
971
- class IMUPreprocessor(VerboseNNModule):
972
- def __init__(
973
- self,
974
- kernel_size: int,
975
- imu_stem: PatchEmbedGeneric,
976
- embed_dim: int,
977
- img_size: List = (6, 2000),
978
- num_cls_tokens: int = 1,
979
- pos_embed_fn: Callable = None,
980
- init_param_style: str = "openclip",
981
- ) -> None:
982
- super().__init__()
983
- stem = imu_stem
984
- self.imu_stem = imu_stem
985
- self.embed_dim = embed_dim
986
- self.use_pos_embed = pos_embed_fn is not None
987
- self.num_cls_tokens = num_cls_tokens
988
- self.kernel_size = kernel_size
989
- self.pos_embed = nn.Parameter(
990
- torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
991
- )
992
-
993
- if self.num_cls_tokens > 0:
994
- self.cls_token = nn.Parameter(
995
- torch.zeros(1, self.num_cls_tokens, self.embed_dim)
996
- )
997
-
998
- self.init_parameters(init_param_style)
999
-
1000
- @torch.no_grad()
1001
- def init_parameters(self, init_param_style):
1002
- nn.init.normal_(self.pos_embed, std=0.01)
1003
-
1004
- if init_param_style == "openclip":
1005
- # OpenCLIP style initialization
1006
- scale = self.embed_dim ** -0.5
1007
-
1008
- if self.num_cls_tokens > 0:
1009
- nn.init.normal_(self.cls_token)
1010
- self.cls_token *= scale
1011
- elif init_param_style == "vit":
1012
- self.cls_token.data.fill_(0)
1013
- else:
1014
- raise ValueError(f"Unknown init {init_param_style}")
1015
-
1016
- def tokenize_input_and_cls_pos(self, input, stem):
1017
- # tokens is of shape B x L x D
1018
- tokens = stem.norm_layer(stem.proj(input))
1019
- assert tokens.ndim == 3
1020
- assert tokens.shape[2] == self.embed_dim
1021
- B = tokens.shape[0]
1022
- if self.num_cls_tokens > 0:
1023
- class_tokens = self.cls_token.expand(
1024
- B, -1, -1
1025
- ) # stole class_tokens impl from Phil Wang, thanks
1026
- tokens = torch.cat((class_tokens, tokens), dim=1)
1027
- if self.use_pos_embed:
1028
- tokens = tokens + self.pos_embed
1029
- return tokens
1030
-
1031
- def forward(self, imu):
1032
- # Patchify
1033
- imu = imu.unfold(
1034
- -1,
1035
- self.kernel_size,
1036
- self.kernel_size,
1037
- ).permute(0, 2, 1, 3)
1038
- imu = imu.reshape(imu.size(0), imu.size(1), -1)
1039
-
1040
- imu_tokens = self.tokenize_input_and_cls_pos(
1041
- imu,
1042
- self.imu_stem,
1043
- )
1044
-
1045
- return_dict = {
1046
- "trunk": {
1047
- "tokens": imu_tokens,
1048
- },
1049
- "head": {},
1050
- }
1051
- return return_dict
1052
-
1053
-
1054
- def cast_if_src_dtype(
1055
- tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
1056
- ):
1057
- updated = False
1058
- if tensor.dtype == src_dtype:
1059
- tensor = tensor.to(dtype=tgt_dtype)
1060
- updated = True
1061
- return tensor, updated
1062
-
1063
-
1064
- class QuickGELU(nn.Module):
1065
- # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
1066
- def forward(self, x: torch.Tensor):
1067
- return x * torch.sigmoid(1.702 * x)
1068
-
1069
-
1070
- class SelectElement(nn.Module):
1071
- def __init__(self, index) -> None:
1072
- super().__init__()
1073
- self.index = index
1074
-
1075
- def forward(self, x):
1076
- assert x.ndim >= 3
1077
- return x[:, self.index, ...]
1078
-
1079
-
1080
- class ReshapeSpatial(nn.Module):
1081
- def __init__(self, shape) -> None:
1082
- super().__init__()
1083
- self.h, self.w = shape
1084
-
1085
- def forward(self, x):
1086
- assert x.ndim >= 3
1087
- return x[:, 1:, ...].reshape(x.shape[0], self.h, self.w, -1), x[:, 0, :]
1088
-
1089
-
1090
- class ReshapeAudio(nn.Module):
1091
- def __init__(self, shape) -> None:
1092
- super().__init__()
1093
- self.h, self.w = shape
1094
-
1095
- def forward(self, x):
1096
- assert x.ndim == 3
1097
- return x[:, 1:, :].reshape(-1, 5, self.h, self.w, x.shape[-1]), x[:, 0, :]
1098
-
1099
-
1100
- class ApplyTwice(nn.Module):
1101
- def __init__(self, module) -> None:
1102
- super().__init__()
1103
- self.module = module
1104
-
1105
- def forward(self, pair):
1106
- return self.module(pair[0]), self.module(pair[1])
1107
-
1108
-
1109
- class SelectEOSAndProject(nn.Module):
1110
- """
1111
- Text Pooling used in OpenCLIP
1112
- """
1113
-
1114
- def __init__(self, proj: nn.Module) -> None:
1115
- super().__init__()
1116
- self.proj = proj
1117
-
1118
- def forward(self, x, seq_len):
1119
- assert x.ndim == 3
1120
- # x is of shape B x L x D
1121
- # take features from the eot embedding (eot_token is the highest number in each sequence)
1122
- x = x[torch.arange(x.shape[0]), seq_len]
1123
- x = self.proj(x)
1124
- return x
1125
-
1126
-
1127
- ModalityType = SimpleNamespace(
1128
- VISION="vision",
1129
- TEXT="text",
1130
- AUDIO="audio",
1131
- THERMAL="thermal",
1132
- DEPTH="depth",
1133
- IMU="imu",
1134
- )
1135
-
1136
-
1137
- class ImageBindModel(nn.Module):
1138
- def __init__(
1139
- self,
1140
- video_frames=2,
1141
- kernel_size=(2, 14, 14),
1142
- audio_kernel_size=16,
1143
- audio_stride=10,
1144
- out_embed_dim=768,
1145
- vision_embed_dim=1024,
1146
- vision_num_blocks=24,
1147
- vision_num_heads=16,
1148
- audio_embed_dim=768,
1149
- audio_num_blocks=12,
1150
- audio_num_heads=12,
1151
- audio_num_mel_bins=128,
1152
- audio_target_len=204,
1153
- audio_drop_path=0.1,
1154
- text_embed_dim=768,
1155
- text_num_blocks=12,
1156
- text_num_heads=12,
1157
- depth_embed_dim=384,
1158
- depth_kernel_size=16,
1159
- depth_num_blocks=12,
1160
- depth_num_heads=8,
1161
- depth_drop_path=0.0,
1162
- thermal_embed_dim=768,
1163
- thermal_kernel_size=16,
1164
- thermal_num_blocks=12,
1165
- thermal_num_heads=12,
1166
- thermal_drop_path=0.0,
1167
- imu_embed_dim=512,
1168
- imu_kernel_size=8,
1169
- imu_num_blocks=6,
1170
- imu_num_heads=8,
1171
- imu_drop_path=0.7,
1172
- ):
1173
- super().__init__()
1174
-
1175
- self.modality_preprocessors = self._create_modality_preprocessors(
1176
- video_frames,
1177
- vision_embed_dim,
1178
- kernel_size,
1179
- text_embed_dim,
1180
- audio_embed_dim,
1181
- audio_kernel_size,
1182
- audio_stride,
1183
- audio_num_mel_bins,
1184
- audio_target_len,
1185
- depth_embed_dim,
1186
- depth_kernel_size,
1187
- thermal_embed_dim,
1188
- thermal_kernel_size,
1189
- imu_embed_dim,
1190
- )
1191
-
1192
- self.modality_trunks = self._create_modality_trunks(
1193
- vision_embed_dim,
1194
- vision_num_blocks,
1195
- vision_num_heads,
1196
- text_embed_dim,
1197
- text_num_blocks,
1198
- text_num_heads,
1199
- audio_embed_dim,
1200
- audio_num_blocks,
1201
- audio_num_heads,
1202
- audio_drop_path,
1203
- depth_embed_dim,
1204
- depth_num_blocks,
1205
- depth_num_heads,
1206
- depth_drop_path,
1207
- thermal_embed_dim,
1208
- thermal_num_blocks,
1209
- thermal_num_heads,
1210
- thermal_drop_path,
1211
- imu_embed_dim,
1212
- imu_num_blocks,
1213
- imu_num_heads,
1214
- imu_drop_path,
1215
- )
1216
-
1217
- self.modality_heads = self._create_modality_heads(
1218
- out_embed_dim,
1219
- vision_embed_dim,
1220
- text_embed_dim,
1221
- audio_embed_dim,
1222
- depth_embed_dim,
1223
- thermal_embed_dim,
1224
- imu_embed_dim,
1225
- )
1226
-
1227
- self.modality_postprocessors = self._create_modality_postprocessors(
1228
- out_embed_dim
1229
- )
1230
-
1231
- def _create_modality_preprocessors(
1232
- self,
1233
- video_frames=2,
1234
- vision_embed_dim=1024,
1235
- kernel_size=(2, 14, 14),
1236
- text_embed_dim=768,
1237
- audio_embed_dim=768,
1238
- audio_kernel_size=16,
1239
- audio_stride=10,
1240
- audio_num_mel_bins=128,
1241
- audio_target_len=204,
1242
- depth_embed_dim=768,
1243
- depth_kernel_size=16,
1244
- thermal_embed_dim=768,
1245
- thermal_kernel_size=16,
1246
- imu_embed_dim=512,
1247
- ):
1248
- rgbt_stem = PatchEmbedGeneric(
1249
- proj_stem=[
1250
- PadIm2Video(pad_type="repeat", ntimes=2),
1251
- nn.Conv3d(
1252
- in_channels=3,
1253
- kernel_size=kernel_size,
1254
- out_channels=vision_embed_dim,
1255
- stride=kernel_size,
1256
- bias=False,
1257
- ),
1258
- ]
1259
- )
1260
- rgbt_preprocessor = RGBDTPreprocessor(
1261
- img_size=[3, video_frames, 224, 224],
1262
- num_cls_tokens=1,
1263
- pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
1264
- rgbt_stem=rgbt_stem,
1265
- depth_stem=None,
1266
- )
1267
-
1268
- text_preprocessor = TextPreprocessor(
1269
- context_length=77,
1270
- vocab_size=49408,
1271
- embed_dim=text_embed_dim,
1272
- causal_masking=True,
1273
- )
1274
-
1275
- audio_stem = PatchEmbedGeneric(
1276
- proj_stem=[
1277
- nn.Conv2d(
1278
- in_channels=1,
1279
- kernel_size=audio_kernel_size,
1280
- stride=audio_stride,
1281
- out_channels=audio_embed_dim,
1282
- bias=False,
1283
- ),
1284
- ],
1285
- norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
1286
- )
1287
- audio_preprocessor = AudioPreprocessor(
1288
- img_size=[1, audio_num_mel_bins, audio_target_len],
1289
- num_cls_tokens=1,
1290
- pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
1291
- audio_stem=audio_stem,
1292
- )
1293
-
1294
- # depth_stem = PatchEmbedGeneric(
1295
- # [
1296
- # nn.Conv2d(
1297
- # kernel_size=depth_kernel_size,
1298
- # in_channels=1,
1299
- # out_channels=depth_embed_dim,
1300
- # stride=depth_kernel_size,
1301
- # bias=False,
1302
- # ),
1303
- # ],
1304
- # norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
1305
- # )
1306
- #
1307
- # depth_preprocessor = RGBDTPreprocessor(
1308
- # img_size=[1, 224, 224],
1309
- # num_cls_tokens=1,
1310
- # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
1311
- # rgbt_stem=None,
1312
- # depth_stem=depth_stem,
1313
- # )
1314
- #
1315
- # thermal_stem = PatchEmbedGeneric(
1316
- # [
1317
- # nn.Conv2d(
1318
- # kernel_size=thermal_kernel_size,
1319
- # in_channels=1,
1320
- # out_channels=thermal_embed_dim,
1321
- # stride=thermal_kernel_size,
1322
- # bias=False,
1323
- # ),
1324
- # ],
1325
- # norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
1326
- # )
1327
- # thermal_preprocessor = ThermalPreprocessor(
1328
- # img_size=[1, 224, 224],
1329
- # num_cls_tokens=1,
1330
- # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
1331
- # thermal_stem=thermal_stem,
1332
- # )
1333
- #
1334
- # imu_stem = PatchEmbedGeneric(
1335
- # [
1336
- # nn.Linear(
1337
- # in_features=48,
1338
- # out_features=imu_embed_dim,
1339
- # bias=False,
1340
- # ),
1341
- # ],
1342
- # norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
1343
- # )
1344
- #
1345
- # imu_preprocessor = IMUPreprocessor(
1346
- # img_size=[6, 2000],
1347
- # num_cls_tokens=1,
1348
- # kernel_size=8,
1349
- # embed_dim=imu_embed_dim,
1350
- # pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
1351
- # imu_stem=imu_stem,
1352
- # )
1353
-
1354
- modality_preprocessors = {
1355
- ModalityType.VISION: rgbt_preprocessor,
1356
- ModalityType.TEXT: text_preprocessor,
1357
- ModalityType.AUDIO: audio_preprocessor,
1358
- # ModalityType.DEPTH: depth_preprocessor,
1359
- # ModalityType.THERMAL: thermal_preprocessor,
1360
- # ModalityType.IMU: imu_preprocessor,
1361
- }
1362
-
1363
- return nn.ModuleDict(modality_preprocessors)
1364
-
1365
- def _create_modality_trunks(
1366
- self,
1367
- vision_embed_dim=1024,
1368
- vision_num_blocks=24,
1369
- vision_num_heads=16,
1370
- text_embed_dim=768,
1371
- text_num_blocks=12,
1372
- text_num_heads=12,
1373
- audio_embed_dim=768,
1374
- audio_num_blocks=12,
1375
- audio_num_heads=12,
1376
- audio_drop_path=0.0,
1377
- depth_embed_dim=768,
1378
- depth_num_blocks=12,
1379
- depth_num_heads=12,
1380
- depth_drop_path=0.0,
1381
- thermal_embed_dim=768,
1382
- thermal_num_blocks=12,
1383
- thermal_num_heads=12,
1384
- thermal_drop_path=0.0,
1385
- imu_embed_dim=512,
1386
- imu_num_blocks=6,
1387
- imu_num_heads=8,
1388
- imu_drop_path=0.7,
1389
- ):
1390
- def instantiate_trunk(
1391
- embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
1392
- ):
1393
- return SimpleTransformer(
1394
- embed_dim=embed_dim,
1395
- num_blocks=num_blocks,
1396
- ffn_dropout_rate=0.0,
1397
- drop_path_rate=drop_path,
1398
- attn_target=partial(
1399
- MultiheadAttention,
1400
- embed_dim=embed_dim,
1401
- num_heads=num_heads,
1402
- bias=True,
1403
- add_bias_kv=add_bias_kv,
1404
- ),
1405
- pre_transformer_layer=nn.Sequential(
1406
- nn.LayerNorm(embed_dim, eps=1e-6)
1407
- if pre_transformer_ln
1408
- else nn.Identity(),
1409
- EinOpsRearrange("b l d -> l b d"),
1410
- ),
1411
- post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
1412
- )
1413
-
1414
- modality_trunks = {}
1415
- modality_trunks[ModalityType.VISION] = instantiate_trunk(
1416
- vision_embed_dim,
1417
- vision_num_blocks,
1418
- vision_num_heads,
1419
- pre_transformer_ln=True,
1420
- add_bias_kv=False,
1421
- drop_path=0.0,
1422
- )
1423
- modality_trunks[ModalityType.TEXT] = instantiate_trunk(
1424
- text_embed_dim,
1425
- text_num_blocks,
1426
- text_num_heads,
1427
- pre_transformer_ln=False,
1428
- add_bias_kv=False,
1429
- drop_path=0.0,
1430
- )
1431
- modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
1432
- audio_embed_dim,
1433
- audio_num_blocks,
1434
- audio_num_heads,
1435
- pre_transformer_ln=False,
1436
- add_bias_kv=True,
1437
- drop_path=audio_drop_path,
1438
- )
1439
- # modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
1440
- # depth_embed_dim,
1441
- # depth_num_blocks,
1442
- # depth_num_heads,
1443
- # pre_transformer_ln=False,
1444
- # add_bias_kv=True,
1445
- # drop_path=depth_drop_path,
1446
- # )
1447
- # modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
1448
- # thermal_embed_dim,
1449
- # thermal_num_blocks,
1450
- # thermal_num_heads,
1451
- # pre_transformer_ln=False,
1452
- # add_bias_kv=True,
1453
- # drop_path=thermal_drop_path,
1454
- # )
1455
- # modality_trunks[ModalityType.IMU] = instantiate_trunk(
1456
- # imu_embed_dim,
1457
- # imu_num_blocks,
1458
- # imu_num_heads,
1459
- # pre_transformer_ln=False,
1460
- # add_bias_kv=True,
1461
- # drop_path=imu_drop_path,
1462
- # )
1463
-
1464
- return nn.ModuleDict(modality_trunks)
1465
-
1466
- def _create_modality_heads(
1467
- self,
1468
- out_embed_dim,
1469
- vision_embed_dim,
1470
- text_embed_dim,
1471
- audio_embed_dim,
1472
- depth_embed_dim,
1473
- thermal_embed_dim,
1474
- imu_embed_dim,
1475
- ):
1476
- modality_heads = {}
1477
-
1478
- modality_heads[ModalityType.VISION] = nn.Sequential(
1479
- nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
1480
- SelectElement(index=0),
1481
- nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
1482
- )
1483
-
1484
- modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
1485
- proj=nn.Sequential(
1486
- nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
1487
- nn.Linear(text_embed_dim, out_embed_dim, bias=False),
1488
- )
1489
- )
1490
-
1491
- modality_heads[ModalityType.AUDIO] = nn.Sequential(
1492
- nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
1493
- SelectElement(index=0),
1494
- nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
1495
- )
1496
-
1497
- # modality_heads[ModalityType.DEPTH] = nn.Sequential(
1498
- # nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
1499
- # SelectElement(index=0),
1500
- # nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
1501
- # )
1502
- #
1503
- # modality_heads[ModalityType.THERMAL] = nn.Sequential(
1504
- # nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
1505
- # SelectElement(index=0),
1506
- # nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
1507
- # )
1508
- #
1509
- # modality_heads[ModalityType.IMU] = nn.Sequential(
1510
- # nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
1511
- # SelectElement(index=0),
1512
- # nn.Dropout(p=0.5),
1513
- # nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
1514
- # )
1515
-
1516
- return nn.ModuleDict(modality_heads)
1517
-
1518
- def _create_modality_postprocessors(self, out_embed_dim):
1519
- modality_postprocessors = {}
1520
-
1521
- modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
1522
- modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
1523
- Normalize(dim=-1), LearnableLogitScaling(learnable=True)
1524
- )
1525
- modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
1526
- Normalize(dim=-1),
1527
- LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
1528
- )
1529
- # modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
1530
- # Normalize(dim=-1),
1531
- # LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
1532
- # )
1533
- # modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
1534
- # Normalize(dim=-1),
1535
- # LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
1536
- # )
1537
- # modality_postprocessors[ModalityType.IMU] = nn.Sequential(
1538
- # Normalize(dim=-1),
1539
- # LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
1540
- # )
1541
-
1542
- return nn.ModuleDict(modality_postprocessors)
1543
-
1544
- def forward(self, inputs):
1545
- outputs = {}
1546
- for modality_key, modality_value in inputs.items():
1547
- reduce_list = (
1548
- modality_value.ndim >= 5
1549
- ) # Audio and Video inputs consist of multiple clips
1550
- if reduce_list:
1551
- B, S = modality_value.shape[:2]
1552
- modality_value = modality_value.reshape(
1553
- B * S, *modality_value.shape[2:]
1554
- )
1555
-
1556
- if modality_value is not None:
1557
- modality_value = self.modality_preprocessors[modality_key](
1558
- **{modality_key: modality_value}
1559
- )
1560
- trunk_inputs = modality_value["trunk"]
1561
- head_inputs = modality_value["head"]
1562
- modality_value = self.modality_trunks[modality_key](**trunk_inputs)
1563
- modality_value = self.modality_heads[modality_key](
1564
- modality_value, **head_inputs
1565
- )
1566
- modality_value = self.modality_postprocessors[modality_key](
1567
- modality_value
1568
- )
1569
-
1570
- if reduce_list:
1571
- modality_value = modality_value.reshape(B, S, -1)
1572
- modality_value = modality_value.mean(dim=1)
1573
-
1574
- outputs[modality_key] = modality_value
1575
-
1576
- return outputs
1577
-
1578
- def reconfigure_head(self, k, v):
1579
- if k == ModalityType.AUDIO:
1580
- return torch.nn.Sequential(v[0], v[2])
1581
- elif k == ModalityType.VISION:
1582
- return torch.nn.Sequential(v[0], v[2])
1583
- else:
1584
- return v
1585
-
1586
- def forward_features(self, inputs):
1587
- outputs = {}
1588
-
1589
- reconfigured_heads = {k: self.reconfigure_head(k, v) for k, v in self.modality_heads.items()}
1590
-
1591
- for modality_key, modality_value in inputs.items():
1592
- reduce_list = (
1593
- modality_value.ndim >= 5
1594
- ) # Audio and Video inputs consist of multiple clips
1595
- if reduce_list:
1596
- B, S = modality_value.shape[:2]
1597
- modality_value = modality_value.reshape(
1598
- B * S, *modality_value.shape[2:]
1599
- )
1600
-
1601
- if modality_value is not None:
1602
- modality_value = self.modality_preprocessors[modality_key](
1603
- **{modality_key: modality_value}
1604
- )
1605
- trunk_inputs = modality_value["trunk"]
1606
- head_inputs = modality_value["head"]
1607
- modality_value = self.modality_trunks[modality_key](**trunk_inputs)
1608
- modality_value = reconfigured_heads[modality_key](
1609
- modality_value, **head_inputs
1610
- )
1611
- modality_value = self.modality_postprocessors[modality_key](
1612
- modality_value
1613
- )
1614
- if modality_key == ModalityType.AUDIO:
1615
- modality_value = ReshapeAudio((12, 19))(modality_value)
1616
- elif modality_key == ModalityType.VISION:
1617
- modality_value = ReshapeSpatial((16, 16))(modality_value)
1618
-
1619
- outputs[modality_key] = modality_value
1620
-
1621
- return outputs
1622
-
1623
-
1624
- def imagebind_huge(output_path, pretrained=False):
1625
- model = ImageBindModel(
1626
- vision_embed_dim=1280,
1627
- vision_num_blocks=32,
1628
- vision_num_heads=16,
1629
- text_embed_dim=1024,
1630
- text_num_blocks=24,
1631
- text_num_heads=16,
1632
- out_embed_dim=1024,
1633
- audio_drop_path=0.1,
1634
- imu_drop_path=0.7,
1635
- )
1636
-
1637
- if pretrained:
1638
- path = os.path.join(output_path, 'models/imagebind_huge.pth')
1639
-
1640
- if not os.path.exists(path):
1641
- print(f"Downloading imagebind weights to {path} ...")
1642
- os.makedirs(os.path.dirname(path), exist_ok=True)
1643
- torch.hub.download_url_to_file(
1644
- "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
1645
- path,
1646
- progress=True,
1647
- )
1648
-
1649
- model.load_state_dict(torch.load(path), strict=False)
1650
-
1651
- return model
1652
-
1653
-
1654
- DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
1655
-
1656
-
1657
- def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
1658
- # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
1659
- waveform -= waveform.mean()
1660
- fbank = torchaudio.compliance.kaldi.fbank(
1661
- waveform,
1662
- htk_compat=True,
1663
- sample_frequency=sample_rate,
1664
- use_energy=False,
1665
- window_type="hanning",
1666
- num_mel_bins=num_mel_bins,
1667
- dither=0.0,
1668
- frame_length=25,
1669
- frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
1670
- )
1671
- # Convert to [mel_bins, num_frames] shape
1672
- fbank = fbank.transpose(0, 1)
1673
- # Pad to target_length
1674
- n_frames = fbank.size(1)
1675
- p = target_length - n_frames
1676
- # if p is too large (say >20%), flash a warning
1677
- if abs(p) / n_frames > 0.2:
1678
- logging.warning(
1679
- "Large gap between audio n_frames(%d) and "
1680
- "target_length (%d). Is the audio_target_length "
1681
- "setting correct?",
1682
- n_frames,
1683
- target_length,
1684
- )
1685
- # cut and pad
1686
- if p > 0:
1687
- fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
1688
- elif p < 0:
1689
- fbank = fbank[:, 0:target_length]
1690
- # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
1691
- # channel image
1692
- fbank = fbank.unsqueeze(0)
1693
- return fbank
1694
-
1695
-
1696
- def get_clip_timepoints(clip_sampler, duration):
1697
- # Read out all clips in this video
1698
- all_clips_timepoints = []
1699
- is_last_clip = False
1700
- end = 0.0
1701
- while not is_last_clip:
1702
- start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
1703
- all_clips_timepoints.append((start, end))
1704
- return all_clips_timepoints
1705
-
1706
-
1707
- def load_and_transform_vision_data(image_paths, device):
1708
- if image_paths is None:
1709
- return None
1710
-
1711
- image_ouputs = []
1712
- for image_path in image_paths:
1713
- data_transform = transforms.Compose(
1714
- [
1715
- transforms.Resize(
1716
- 224, interpolation=transforms.InterpolationMode.BICUBIC
1717
- ),
1718
- transforms.CenterCrop(224),
1719
- transforms.ToTensor(),
1720
- transforms.Normalize(
1721
- mean=(0.48145466, 0.4578275, 0.40821073),
1722
- std=(0.26862954, 0.26130258, 0.27577711),
1723
- ),
1724
- ]
1725
- )
1726
- with open(image_path, "rb") as fopen:
1727
- image = Image.open(fopen).convert("RGB")
1728
-
1729
- image = data_transform(image).to(device)
1730
- image_ouputs.append(image)
1731
- return torch.stack(image_ouputs, dim=0)
1732
-
1733
-
1734
- def load_and_transform_audio_data(
1735
- audio_paths,
1736
- device,
1737
- num_mel_bins=128,
1738
- target_length=204,
1739
- sample_rate=16000,
1740
- clip_duration=2,
1741
- clips_per_video=3,
1742
- mean=-4.268,
1743
- std=9.138,
1744
- ):
1745
- from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
1746
-
1747
- if audio_paths is None:
1748
- return None
1749
-
1750
- audio_outputs = []
1751
- clip_sampler = ConstantClipsPerVideoSampler(
1752
- clip_duration=clip_duration, clips_per_video=clips_per_video
1753
- )
1754
-
1755
- for audio_path in audio_paths:
1756
- waveform, sr = torchaudio.load(audio_path)
1757
- if sample_rate != sr:
1758
- waveform = torchaudio.functional.resample(
1759
- waveform, orig_freq=sr, new_freq=sample_rate
1760
- )
1761
- all_clips_timepoints = get_clip_timepoints(
1762
- clip_sampler, waveform.size(1) / sample_rate
1763
- )
1764
- all_clips = []
1765
- for clip_timepoints in all_clips_timepoints:
1766
- waveform_clip = waveform[
1767
- :,
1768
- int(clip_timepoints[0] * sample_rate): int(
1769
- clip_timepoints[1] * sample_rate
1770
- ),
1771
- ]
1772
- waveform_melspec = waveform2melspec(
1773
- waveform_clip, sample_rate, num_mel_bins, target_length
1774
- )
1775
- all_clips.append(waveform_melspec)
1776
-
1777
- normalize = transforms.Normalize(mean=mean, std=std)
1778
- all_clips = [normalize(ac).to(device) for ac in all_clips]
1779
-
1780
- all_clips = torch.stack(all_clips, dim=0)
1781
- audio_outputs.append(all_clips)
1782
-
1783
- return torch.stack(audio_outputs, dim=0)
1784
-
1785
-
1786
- class UnNormalize(object):
1787
- def __init__(self, mean, std):
1788
- self.mean = mean
1789
- self.std = std
1790
-
1791
- def __call__(self, image):
1792
- image2 = torch.clone(image)
1793
- for t, m, s in zip(image2, self.mean, self.std):
1794
- t.mul_(s).add_(m)
1795
- return image2
1796
-
1797
-
1798
- norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1799
- unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
1800
-
1801
-
1802
- class TorchPCA(object):
1803
-
1804
- def __init__(self, n_components):
1805
- self.n_components = n_components
1806
-
1807
- def fit(self, X):
1808
- self.mean_ = X.mean(dim=0)
1809
- unbiased = X - self.mean_.unsqueeze(0)
1810
- U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
1811
- self.components_ = V.T
1812
- self.singular_values_ = S
1813
- return self
1814
-
1815
- def transform(self, X):
1816
- t0 = X - self.mean_.unsqueeze(0)
1817
- projected = t0 @ self.components_.T
1818
- return projected
1819
-
1820
-
1821
- def pca(image_feats_list, dim=3, fit_pca=None):
1822
- # from sklearn.decomposition import PCA
1823
-
1824
- device = image_feats_list[0].device
1825
-
1826
- def flatten(tensor, target_size=None):
1827
- if target_size is not None and fit_pca is None:
1828
- F.interpolate(tensor, (target_size, target_size), mode="bilinear")
1829
- B, C, H, W = tensor.shape
1830
- return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
1831
-
1832
- if len(image_feats_list) > 1 and fit_pca is None:
1833
- target_size = image_feats_list[0].shape[2]
1834
- else:
1835
- target_size = None
1836
-
1837
- flattened_feats = []
1838
- for feats in image_feats_list:
1839
- flattened_feats.append(flatten(feats, target_size))
1840
- x = torch.cat(flattened_feats, dim=0)
1841
-
1842
- if fit_pca is None:
1843
- # fit_pca = PCA(n_components=dim, svd_solver='arpack').fit(np.nan_to_num(x.detach().numpy()))
1844
- fit_pca = TorchPCA(n_components=dim).fit(x)
1845
-
1846
- reduced_feats = []
1847
- for feats in image_feats_list:
1848
- # x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
1849
- x_red = fit_pca.transform(flatten(feats))
1850
- x_red -= x_red.min(dim=0, keepdim=True).values
1851
- x_red /= x_red.max(dim=0, keepdim=True).values
1852
- B, C, H, W = feats.shape
1853
- reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
1854
-
1855
- return reduced_feats, fit_pca
1856
-
1857
-
1858
- def my_load_audio(audio_file):
1859
- loaded_waveform, obs_sr = torchaudio.load(audio_file)
1860
- loaded_waveform = loaded_waveform[0]
1861
-
1862
- neg_waveform, neg_obs_sr = None, None
1863
- from data.AVDatasets import prep_waveform
1864
-
1865
- (waveform,
1866
- spectrogram,
1867
- audio_length,
1868
- total_length,
1869
- original_length,
1870
- mask,
1871
- pos_mask) = prep_waveform(
1872
- loaded_waveform,
1873
- obs_sr,
1874
- 10,
1875
- 128,
1876
- -4.268,
1877
- 9.138,
1878
- 16000,
1879
- True,
1880
- False,
1881
- False,
1882
- neg_waveform,
1883
- neg_obs_sr,
1884
- False,
1885
- )
1886
-
1887
- patch_size = 204
1888
- n_tiles = spectrogram.shape[1] // patch_size
1889
- assert n_tiles == 5
1890
-
1891
- patches = []
1892
- for i in range(n_tiles):
1893
- patches.append(spectrogram[:, i * patch_size:(i + 1) * patch_size, :])
1894
-
1895
- patches = torch.cat(patches, dim=0).permute(0, 2, 1).unsqueeze(1)
1896
- return patches
1897
-
1898
-
1899
- class ImageBindImageFeaturizer(nn.Module):
1900
-
1901
- def __init__(self, output_path, model=None):
1902
- super().__init__()
1903
- if model is not None:
1904
- self.model = model
1905
- else:
1906
- self.model = imagebind_huge(output_path, pretrained=True).cuda()
1907
-
1908
- def forward(self, image, include_cls):
1909
- inputs = {
1910
- ModalityType.VISION: image,
1911
- }
1912
-
1913
- patch_tokens, cls_tokens = self.model.forward_features(inputs)[ModalityType.VISION]
1914
- patch_tokens = patch_tokens.permute(0, 3, 1, 2)
1915
-
1916
- if include_cls:
1917
- return patch_tokens, cls_tokens
1918
- else:
1919
- return patch_tokens
1920
-
1921
-
1922
- class ImageBindAudioFeaturizer(nn.Module):
1923
-
1924
- def __init__(self, output_path, model=None):
1925
- super().__init__()
1926
- if model is not None:
1927
- self.model = model
1928
- else:
1929
- self.model = imagebind_huge(output_path, pretrained=True).cuda()
1930
-
1931
- def forward(self, spec, include_cls):
1932
-
1933
- patch_size = 204
1934
- n_tiles = spec.shape[2] // patch_size
1935
- assert n_tiles == 5
1936
-
1937
- patches = []
1938
- for i in range(n_tiles):
1939
- patches.append(spec[:, :, i * patch_size:(i + 1) * patch_size, :])
1940
-
1941
- patches = torch.cat(patches, dim=1).permute(0, 1, 3, 2).unsqueeze(2)
1942
-
1943
- inputs = {
1944
- ModalityType.AUDIO: patches,
1945
- }
1946
-
1947
- patch_tokens, cls_token = self.model.forward_features(inputs)[ModalityType.AUDIO]
1948
-
1949
- patch_tokens = patch_tokens.permute(0, 4, 2, 1, 3)
1950
- b, c, h, p, w = patch_tokens.shape
1951
- patch_tokens = patch_tokens.reshape(b, c, h, w * p)
1952
-
1953
- cls_token = cls_token.reshape(b, p, -1).mean(1)
1954
-
1955
- if include_cls:
1956
- return patch_tokens, cls_token
1957
- else:
1958
- return patch_tokens
1959
-
1960
-
1961
- if __name__ == "__main__":
1962
- image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"]
1963
- audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"]
1964
-
1965
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
1966
-
1967
- # Instantiate model
1968
- model = imagebind_huge("../../", pretrained=True)
1969
- model.eval()
1970
- model.to(device)
1971
-
1972
- audio_inputs = torch.cat([my_load_audio(af).unsqueeze(0) for af in audio_paths], dim=0).cuda()
1973
- # Load data
1974
- inputs = {
1975
- ModalityType.VISION: load_and_transform_vision_data(image_paths, device),
1976
- # ModalityType.AUDIO: load_and_transform_audio_data(audio_paths, device, clip_duration=2, clips_per_video=5),
1977
- ModalityType.AUDIO: audio_inputs,
1978
-
1979
- }
1980
-
1981
- with torch.no_grad():
1982
- embeddings = model.forward_features(inputs)
1983
- cls_tokens = model.forward(inputs)
1984
-
1985
- audio_cls_token = embeddings["audio"][1].reshape(3, 5, -1).mean(1)
1986
-
1987
- sims1 = torch.einsum(
1988
- "bc,dc->bd",
1989
- embeddings["vision"][1],
1990
- audio_cls_token)
1991
-
1992
- print(torch.softmax(sims1, dim=1).cpu().numpy())
1993
- #
1994
- # sims2 = torch.einsum(
1995
- # "bc,dc->bd",
1996
- # embeddings["vision"].mean(1).mean(1),
1997
- # embeddings["audio"].mean(1).mean(1).mean(1)
1998
- # )
1999
- #
2000
- # print(torch.softmax(sims2, dim=1).cpu().numpy())
2001
- #
2002
- #
2003
- # img_num = 0
2004
- # img_feats = F.normalize(embeddings["vision"].permute(0, 3, 1, 2), dim=1)
2005
- # [red_img_feats], fit_pca = pca([img_feats])
2006
- #
2007
- # fig, axes = plt.subplots(2, 2, figsize=(4 * 2, 4 * 2))
2008
- # axes[0][0].imshow(unnorm(inputs["vision"][0].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu())
2009
- # axes[0][1].imshow(unnorm(inputs["vision"][1].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu())
2010
- # axes[1][0].imshow(red_img_feats[0].permute(1, 2, 0).detach().cpu())
2011
- # axes[1][1].imshow(red_img_feats[1].permute(1, 2, 0).detach().cpu())
2012
- # plt.tight_layout()
2013
- # plt.show()
2014
- #
2015
- audio_embs = F.normalize(embeddings["audio"][0], dim=-1)
2016
- b, n, h, w, c = audio_embs.shape
2017
-
2018
- audio_embs = audio_embs.permute(0, 4, 2, 1, 3).reshape(b, c, h, w * n)
2019
-
2020
- b, n, c, h, w = inputs["audio"].shape
2021
- audio_inputs = inputs["audio"].permute(0, 2, 3, 1, 4).reshape(b, c, h, w * n)
2022
-
2023
- print("here")
2024
-
2025
- for img_num in range(3):
2026
- [red_audio], fit_pca = pca([audio_embs[img_num].unsqueeze(0)])
2027
- fig, axes = plt.subplots(2, 1, figsize=(10 * 1, 4 * 2))
2028
- axes[0].imshow(audio_inputs[img_num, 0].detach().cpu())
2029
- axes[1].imshow(red_audio[0].permute(1, 2, 0).detach().cpu())
2030
- plt.tight_layout()
2031
- plt.show()
2032
-
2033
- print("here")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/featurizers/__init__.py DELETED
File without changes
DenseAV/denseav/plotting.py DELETED
@@ -1,244 +0,0 @@
1
- import os
2
- from collections import defaultdict
3
-
4
- import matplotlib.colors as mcolors
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
- import scipy.io.wavfile as wavfile
8
- import torch
9
- import torch.nn.functional as F
10
- import torchvision
11
- from moviepy.editor import VideoFileClip, AudioFileClip
12
- from base64 import b64encode
13
- from denseav.shared import pca
14
-
15
-
16
- def write_video_with_audio(video_frames, audio_array, video_fps, audio_fps, output_path):
17
- """
18
- Writes video frames and audio to a specified path.
19
-
20
- Parameters:
21
- - video_frames: torch.Tensor of shape (num_frames, height, width, channels)
22
- - audio_array: torch.Tensor of shape (num_samples, num_channels)
23
- - video_fps: int, frames per second of the video
24
- - audio_fps: int, sample rate of the audio
25
- - output_path: str, path to save the final video with audio
26
- """
27
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
28
-
29
- temp_video_path = output_path.replace('.mp4', '_temp.mp4')
30
- temp_audio_path = output_path.replace('.mp4', '_temp_audio.wav')
31
- video_options = {
32
- 'crf': '23',
33
- 'preset': 'slow',
34
- 'bit_rate': '1000k'}
35
-
36
- if audio_array is not None:
37
- torchvision.io.write_video(
38
- filename=temp_video_path,
39
- video_array=video_frames,
40
- fps=video_fps,
41
- options=video_options
42
- )
43
-
44
- wavfile.write(temp_audio_path, audio_fps, audio_array.cpu().to(torch.float64).permute(1, 0).numpy())
45
- video_clip = VideoFileClip(temp_video_path)
46
- audio_clip = AudioFileClip(temp_audio_path)
47
- final_clip = video_clip.set_audio(audio_clip)
48
- final_clip.write_videofile(output_path, codec='libx264', verbose=False)
49
- os.remove(temp_video_path)
50
- os.remove(temp_audio_path)
51
- else:
52
- torchvision.io.write_video(
53
- filename=output_path,
54
- video_array=video_frames,
55
- fps=video_fps,
56
- options=video_options
57
- )
58
-
59
-
60
- def alpha_blend_layers(layers):
61
- blended_image = layers[0]
62
- for layer in layers[1:]:
63
- rgb1, alpha1 = blended_image[:, :3, :, :], blended_image[:, 3:4, :, :]
64
- rgb2, alpha2 = layer[:, :3, :, :], layer[:, 3:4, :, :]
65
- alpha_out = alpha2 + alpha1 * (1 - alpha2)
66
- rgb_out = (rgb2 * alpha2 + rgb1 * alpha1 * (1 - alpha2)) / alpha_out.clamp(min=1e-7)
67
- blended_image = torch.cat([rgb_out, alpha_out], dim=1)
68
- return (blended_image[:, :3] * 255).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
69
-
70
-
71
- def _prep_sims_for_plotting(sim_by_head, frames):
72
- with torch.no_grad():
73
- results = defaultdict(list)
74
- n_frames, _, vh, vw = frames.shape
75
-
76
- sims = sim_by_head.max(dim=1).values
77
-
78
- n_audio_feats = sims.shape[-1]
79
- for frame_num in range(n_frames):
80
- selected_audio_feat = int((frame_num / n_frames) * n_audio_feats)
81
-
82
- selected_sim = F.interpolate(
83
- sims[frame_num, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
84
- size=(vh, vw),
85
- mode="bicubic")
86
-
87
- results["sims_all"].append(selected_sim)
88
-
89
- for head in range(sim_by_head.shape[1]):
90
- selected_sim = F.interpolate(
91
- sim_by_head[frame_num, head, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
92
- size=(vh, vw),
93
- mode="bicubic")
94
- results[f"sims_{head + 1}"].append(selected_sim)
95
-
96
- results = {k: torch.cat(v, dim=0) for k, v in results.items()}
97
- return results
98
-
99
-
100
- def get_plasma_with_alpha():
101
- plasma = plt.cm.plasma(np.linspace(0, 1, 256))
102
- alphas = np.linspace(0, 1, 256)
103
- plasma_with_alpha = np.zeros((256, 4))
104
- plasma_with_alpha[:, 0:3] = plasma[:, 0:3]
105
- plasma_with_alpha[:, 3] = alphas
106
- return mcolors.ListedColormap(plasma_with_alpha)
107
-
108
-
109
- def get_inferno_with_alpha_2(alpha=0.5, k=30):
110
- k_fraction = k / 100.0
111
- custom_cmap = np.zeros((256, 4))
112
- threshold_index = int(k_fraction * 256)
113
- custom_cmap[:threshold_index, :3] = 0 # RGB values for black
114
- custom_cmap[:threshold_index, 3] = alpha # Alpha value
115
- remaining_inferno = plt.cm.inferno(np.linspace(0, 1, 256 - threshold_index))
116
- custom_cmap[threshold_index:, :3] = remaining_inferno[:, :3]
117
- custom_cmap[threshold_index:, 3] = alpha # Alpha value
118
- return mcolors.ListedColormap(custom_cmap)
119
-
120
-
121
- def get_inferno_with_alpha():
122
- plasma = plt.cm.inferno(np.linspace(0, 1, 256))
123
- alphas = np.linspace(0, 1, 256)
124
- plasma_with_alpha = np.zeros((256, 4))
125
- plasma_with_alpha[:, 0:3] = plasma[:, 0:3]
126
- plasma_with_alpha[:, 3] = alphas
127
- return mcolors.ListedColormap(plasma_with_alpha)
128
-
129
-
130
- red_cmap = mcolors.LinearSegmentedColormap('RedMap', segmentdata={
131
- 'red': [(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
132
- 'green': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
133
- 'blue': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
134
- 'alpha': [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
135
- })
136
-
137
- blue_cmap = mcolors.LinearSegmentedColormap('BlueMap', segmentdata={
138
- 'red': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
139
- 'green': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
140
- 'blue': [(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
141
- 'alpha': [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
142
- })
143
-
144
-
145
- def plot_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
146
- prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
147
- n_frames, _, vh, vw = frames.shape
148
- sims_all = prepped_sims["sims_all"].clamp_min(0)
149
- sims_all -= sims_all.min()
150
- sims_all = sims_all / sims_all.max()
151
- cmap = get_inferno_with_alpha()
152
- layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
153
- layer2 = torch.tensor(cmap(sims_all.squeeze().detach().cpu())).permute(0, 3, 1, 2)
154
- write_video_with_audio(
155
- alpha_blend_layers([layer1, layer2]),
156
- audio,
157
- video_fps,
158
- audio_fps,
159
- output_filename)
160
-
161
-
162
- def plot_2head_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
163
- prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
164
- sims_1 = prepped_sims["sims_1"]
165
- sims_2 = prepped_sims["sims_2"]
166
-
167
- n_frames, _, vh, vw = frames.shape
168
-
169
- mask = sims_1 > sims_2
170
- sims_1 *= mask
171
- sims_2 *= (~mask)
172
-
173
- sims_1 = sims_1.clamp_min(0)
174
- sims_1 -= sims_1.min()
175
- sims_1 = sims_1 / sims_1.max()
176
-
177
- sims_2 = sims_2.clamp_min(0)
178
- sims_2 -= sims_2.min()
179
- sims_2 = sims_2 / sims_2.max()
180
-
181
- layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
182
- layer2_head1 = torch.tensor(red_cmap(sims_1.squeeze().detach().cpu())).permute(0, 3, 1, 2)
183
- layer2_head2 = torch.tensor(blue_cmap(sims_2.squeeze().detach().cpu())).permute(0, 3, 1, 2)
184
-
185
- write_video_with_audio(
186
- alpha_blend_layers([layer1, layer2_head1, layer2_head2]),
187
- audio,
188
- video_fps,
189
- audio_fps,
190
- output_filename)
191
-
192
-
193
- def plot_feature_video(image_feats,
194
- audio_feats,
195
- frames,
196
- audio,
197
- video_fps,
198
- audio_fps,
199
- video_filename,
200
- audio_filename):
201
- with torch.no_grad():
202
- image_feats_ = image_feats.cpu()
203
- audio_feats_ = audio_feats.cpu()
204
- [red_img_feats, red_audio_feats], _ = pca([
205
- image_feats_,
206
- audio_feats_, # .tile(image_feats_.shape[0], 1, 1, 1)
207
- ])
208
- _, _, vh, vw = frames.shape
209
- red_img_feats = F.interpolate(red_img_feats, size=(vh, vw), mode="bicubic")
210
- red_audio_feats = red_audio_feats[0].unsqueeze(0)
211
- red_audio_feats = F.interpolate(red_audio_feats, size=(50, red_img_feats.shape[0]), mode="bicubic")
212
-
213
- write_video_with_audio(
214
- (red_img_feats.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8),
215
- audio,
216
- video_fps,
217
- audio_fps,
218
- video_filename)
219
-
220
- red_audio_feats_expanded = red_audio_feats.tile(red_img_feats.shape[0], 1, 1, 1)
221
- red_audio_feats_expanded = F.interpolate(red_audio_feats_expanded, scale_factor=6, mode="bicubic")
222
- for i in range(red_img_feats.shape[0]):
223
- center_index = i * 6
224
- min_index = max(center_index - 2, 0)
225
- max_index = min(center_index + 2, red_audio_feats_expanded.shape[-1])
226
- red_audio_feats_expanded[i, :, :, min_index:max_index] = 1
227
-
228
- write_video_with_audio(
229
- (red_audio_feats_expanded.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8),
230
- audio,
231
- video_fps,
232
- audio_fps,
233
- audio_filename)
234
-
235
-
236
- def display_video_in_notebook(path):
237
- from IPython.display import HTML, display
238
- mp4 = open(path, 'rb').read()
239
- data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
240
- display(HTML("""
241
- <video width=400 controls>
242
- <source src="%s" type="video/mp4">
243
- </video>
244
- """ % data_url))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/saved_models.py DELETED
@@ -1,262 +0,0 @@
1
- import os
2
- import re
3
- from os.path import join
4
-
5
- import torch
6
-
7
-
8
-
9
- def get_latest(name, checkpoint_dir, extra_args=None):
10
- if extra_args is None:
11
- extra_args = dict()
12
- files = os.listdir(join(checkpoint_dir, name))
13
- steps = torch.tensor([int(f.split("step=")[-1].split(".")[0]) for f in files])
14
- selected = files[steps.argmax()]
15
- return dict(
16
- chkpt_name=os.path.join(name, selected),
17
- extra_args=extra_args)
18
-
19
-
20
- DS_PARAM_REGEX = r'_forward_module\.(.+)'
21
-
22
-
23
- def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None):
24
- '''
25
- Creates a PyTorch Lightning checkpoint from the DeepSpeed checkpoint directory, while patching
26
- in parameters which are improperly loaded by the DeepSpeed conversion utility.
27
- deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
28
- pl_ckpt_path: Path to the reconstructed PyTorch Lightning checkpoint. If not specified, will be
29
- placed in the same directory as the DeepSpeed checkpoint directory with the same name but
30
- a .pt extension.
31
- Returns: path to the converted checkpoint.
32
- '''
33
- from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
34
-
35
-
36
- if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)):
37
- raise ValueError(
38
- 'args.ckpt_dir should point to the checkpoint directory'
39
- ' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").'
40
- )
41
-
42
- # Convert state dict to PyTorch format
43
- if not pl_ckpt_path:
44
- pl_ckpt_path = f'{deepspeed_ckpt_path[:-4]}pt' # .ckpt --> .pt
45
-
46
- if not os.path.exists(pl_ckpt_path):
47
- convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path)
48
-
49
- # Patch in missing parameters that failed to be converted by DeepSpeed utility
50
- pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path)
51
- torch.save(pl_ckpt, pl_ckpt_path)
52
-
53
- return pl_ckpt_path
54
-
55
-
56
- def get_optim_files(checkpoint_dir):
57
- files = sorted([f for f in os.listdir(checkpoint_dir) if "optim" in f])
58
- return [join(checkpoint_dir, f) for f in files]
59
-
60
-
61
- def get_model_state_file(checkpoint_dir, zero_stage):
62
- f = [f for f in os.listdir(checkpoint_dir) if "model_states" in f][0]
63
- return join(checkpoint_dir, f)
64
-
65
-
66
- def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str):
67
- '''
68
- Merges tensors with keys in the DeepSpeed checkpoint but not in the fp32_checkpoint
69
- into the fp32 state dict.
70
- deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
71
- fp32_ckpt_path: Path to the reconstructed
72
- '''
73
- from pytorch_lightning.utilities.deepspeed import ds_checkpoint_dir
74
-
75
-
76
- # This first part is based on pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict
77
- checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path)
78
- optim_files = get_optim_files(checkpoint_dir)
79
- optim_state = torch.load(optim_files[0], map_location='cpu')
80
- zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
81
- deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage)
82
-
83
- # Start adding all parameters from DeepSpeed ckpt to generated PyTorch Lightning ckpt
84
- ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu')
85
- ds_sd = ds_ckpt['module']
86
-
87
- fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
88
- fp32_sd = fp32_ckpt['state_dict']
89
-
90
- for k, v in ds_sd.items():
91
- try:
92
- match = re.match(DS_PARAM_REGEX, k)
93
- param_name = match.group(1)
94
- except:
95
- print(f'Failed to extract parameter from DeepSpeed key {k}')
96
- continue
97
-
98
- v = v.to(torch.float32)
99
- if param_name not in fp32_sd:
100
- print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd')
101
- fp32_sd[param_name] = v
102
- else:
103
- assert torch.allclose(v, fp32_sd[param_name].to(torch.float32), atol=1e-2)
104
-
105
- return fp32_ckpt
106
-
107
-
108
- def get_version_and_step(f, i):
109
- step = f.split("step=")[-1].split(".")[0]
110
- if "-v" in step:
111
- [step, version] = step.split("-v")
112
- else:
113
- step, version = step, 0
114
-
115
- return int(version), int(step), i
116
-
117
-
118
- def get_latest_ds(name, extra_args=None):
119
- if extra_args is None:
120
- extra_args = dict()
121
- files = os.listdir(f"../checkpoints/{name}")
122
- latest = sorted([get_version_and_step(f, i) for i, f in enumerate(files)], reverse=True)[0]
123
- selected = files[latest[-1]]
124
- # print(f"Selecting file: {selected}")
125
- ds_chkpt = join(name, selected)
126
- reg_chkpt = join(name + "_fp32", selected)
127
- reg_chkpt_path = join("../checkpoints", reg_chkpt)
128
- if not os.path.exists(reg_chkpt_path):
129
- os.makedirs(os.path.dirname(reg_chkpt_path), exist_ok=True)
130
- print(f"Checkpoint {reg_chkpt} does not exist, converting from deepspeed")
131
- convert_deepspeed_checkpoint(join("../checkpoints", ds_chkpt), reg_chkpt_path)
132
- return dict(
133
- chkpt_name=reg_chkpt,
134
- extra_args=extra_args)
135
-
136
-
137
- def get_all_models_in_dir(name, checkpoint_dir, extra_args=None):
138
- ret = {}
139
- for model_dir in os.listdir(join(checkpoint_dir, name)):
140
- full_name = f"{name}/{model_dir}/train"
141
- # print(f'"{full_name}",')
142
- ret[full_name] = get_latest(full_name, checkpoint_dir, extra_args)
143
- return ret
144
-
145
-
146
- def saved_model_dict(checkpoint_dir):
147
- model_info = {
148
-
149
- **get_all_models_in_dir(
150
- "9-5-23-mixed",
151
- checkpoint_dir,
152
- extra_args=dict(
153
- mixup_weight=0.0,
154
- sim_use_cls=False,
155
- audio_pool_width=1,
156
- memory_buffer_size=0,
157
- loss_leak=0.0)
158
- ),
159
-
160
- **get_all_models_in_dir(
161
- "1-23-24-rebuttal-heads",
162
- checkpoint_dir,
163
- extra_args=dict(
164
- loss_leak=0.0)
165
- ),
166
-
167
- **get_all_models_in_dir(
168
- "11-8-23",
169
- checkpoint_dir,
170
- extra_args=dict(loss_leak=0.0)),
171
-
172
- **get_all_models_in_dir(
173
- "10-30-23-3",
174
- checkpoint_dir,
175
- extra_args=dict(loss_leak=0.0)),
176
-
177
- "davenet": dict(
178
- chkpt_name=None,
179
- extra_args=dict(
180
- audio_blur=1,
181
- image_model_type="davenet",
182
- image_aligner_type=None,
183
- audio_model_type="davenet",
184
- audio_aligner_type=None,
185
- audio_input="davenet_spec",
186
- use_cached_embs=False,
187
- dropout=False,
188
- sim_agg_heads=1,
189
- nonneg_sim=False,
190
- audio_lora=False,
191
- image_lora=False,
192
- norm_vectors=False,
193
- ),
194
- data_args=dict(
195
- use_cached_embs=False,
196
- use_davenet_spec=True,
197
- override_target_length=20,
198
- audio_model_type="davenet",
199
- ),
200
- ),
201
-
202
- "cavmae": dict(
203
- chkpt_name=None,
204
- extra_args=dict(
205
- audio_blur=1,
206
- image_model_type="cavmae",
207
- image_aligner_type=None,
208
- audio_model_type="cavmae",
209
- audio_aligner_type=None,
210
- audio_input="spec",
211
- use_cached_embs=False,
212
- sim_agg_heads=1,
213
- dropout=False,
214
- nonneg_sim=False,
215
- audio_lora=False,
216
- image_lora=False,
217
- norm_vectors=False,
218
- learn_audio_cls=False,
219
- sim_agg_type="cavmae",
220
- ),
221
- data_args=dict(
222
- use_cached_embs=False,
223
- use_davenet_spec=True,
224
- audio_model_type="cavmae",
225
- override_target_length=10,
226
- ),
227
- ),
228
-
229
- "imagebind": dict(
230
- chkpt_name=None,
231
- extra_args=dict(
232
- audio_blur=1,
233
- image_model_type="imagebind",
234
- image_aligner_type=None,
235
- audio_model_type="imagebind",
236
- audio_aligner_type=None,
237
- audio_input="spec",
238
- use_cached_embs=False,
239
- sim_agg_heads=1,
240
- dropout=False,
241
- nonneg_sim=False,
242
- audio_lora=False,
243
- image_lora=False,
244
- norm_vectors=False,
245
- learn_audio_cls=False,
246
- sim_agg_type="imagebind",
247
- ),
248
- data_args=dict(
249
- use_cached_embs=False,
250
- use_davenet_spec=True,
251
- audio_model_type="imagebind",
252
- override_target_length=10,
253
- ),
254
- ),
255
-
256
- }
257
-
258
- model_info["denseav_language"] = model_info["10-30-23-3/places_base/train"]
259
- model_info["denseav_sound"] = model_info["11-8-23/hubert_1h_asf_cls_full_image_train_small_lr/train"]
260
- model_info["denseav_2head"] = model_info["1-23-24-rebuttal-heads/mixed-2h/train"]
261
-
262
- return model_info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/shared.py DELETED
@@ -1,739 +0,0 @@
1
- import random
2
- from collections import defaultdict, deque
3
- from typing import Any
4
-
5
- import math
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
- import torch
9
- import torch.distributed as dist
10
- import torch.nn.functional as F
11
- import torchaudio
12
- import torchvision.transforms as T
13
- from PIL import Image
14
- from torch.utils.data import Dataset
15
- from torchaudio.functional import resample
16
-
17
-
18
- class UnNormalize(object):
19
- def __init__(self, mean, std):
20
- self.mean = mean
21
- self.std = std
22
-
23
- def __call__(self, image):
24
- image2 = torch.clone(image)
25
- for t, m, s in zip(image2, self.mean, self.std):
26
- t.mul_(s).add_(m)
27
- return image2
28
-
29
-
30
- class SliceDataset(Dataset):
31
-
32
- def __init__(self, ds, start, end):
33
- self.ds = ds
34
- self.start = start
35
- self.end = end
36
-
37
- def __len__(self):
38
- return self.end - self.start
39
-
40
- def __getitem__(self, item):
41
- return self.ds[item + self.start]
42
-
43
-
44
- class SubsetDataset(Dataset):
45
-
46
- def __init__(self, ds, subset):
47
- self.ds = ds
48
- self.subset = subset
49
-
50
- def __len__(self):
51
- return len(self.subset)
52
-
53
- def __getitem__(self, item):
54
- return self.ds[self.subset[item]]
55
-
56
-
57
- norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
58
- unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
59
-
60
-
61
- def crop_to_divisor(x, patch_size):
62
- if len(x.shape) == 3:
63
- C, H, W = x.shape
64
- return x[:, :(patch_size * (H // patch_size)), :(patch_size * (W // patch_size))]
65
- elif len(x.shape) == 4:
66
- B, C, H, W = x.shape
67
- return x[:, :, :(patch_size * (H // patch_size)), :(patch_size * (W // patch_size))]
68
- else:
69
- raise ValueError("x should have 3 or 4 dimensions")
70
-
71
-
72
- def _remove_axes(ax):
73
- ax.xaxis.set_major_formatter(plt.NullFormatter())
74
- ax.yaxis.set_major_formatter(plt.NullFormatter())
75
- ax.set_xticks([])
76
- ax.set_yticks([])
77
-
78
-
79
- def remove_axes(axes):
80
- if len(axes.shape) == 2:
81
- for ax1 in axes:
82
- for ax in ax1:
83
- _remove_axes(ax)
84
- else:
85
- for ax in axes:
86
- _remove_axes(ax)
87
-
88
-
89
- def get_image_featurizer(name, token_type="key", **kwargs):
90
- name = name.lower()
91
-
92
- if name == "vit":
93
- from denseav.featurizers.DINO import DINOFeaturizer
94
- patch_size = 16
95
- model = DINOFeaturizer("vit_small_patch16_224", patch_size, token_type)
96
- dim = 384
97
- elif name == "dino16":
98
- from denseav.featurizers.DINO import DINOFeaturizer
99
- patch_size = 16
100
- model = DINOFeaturizer("dino_vits16", patch_size, token_type)
101
- dim = 384
102
- elif name == "dino8":
103
- from denseav.featurizers.DINO import DINOFeaturizer
104
- patch_size = 8
105
- model = DINOFeaturizer("dino_vits8", patch_size, token_type)
106
- dim = 384
107
- elif name == "clip":
108
- from denseav.featurizers.CLIP import CLIPFeaturizer
109
- patch_size = 16
110
- model = CLIPFeaturizer()
111
- dim = 512
112
- elif name == "cavmae":
113
- from denseav.featurizers.CAVMAE import CAVMAEImageFeaturizer
114
- model = CAVMAEImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
115
- dim = 768
116
- patch_size = 16
117
- elif name == "fnac":
118
- from denseav.featurizers.FNACAVL import FNACImageFeaturizer
119
- model = FNACImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
120
- dim = 512
121
- patch_size = 16
122
- elif name == "imagebind":
123
- from denseav.featurizers.ImageBind import ImageBindImageFeaturizer
124
- model = ImageBindImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
125
- dim = 1024
126
- patch_size = 16
127
- elif name == "resnet50":
128
- from torchvision import models
129
- model = models.resnet50(pretrained=True)
130
- model = torch.nn.Sequential(*list(model.children())[:-2])
131
- patch_size = 1
132
- dim = 2048
133
- elif name == "davenet":
134
- from fdenseav.eaturizers.DAVENet import DavenetImageFeaturizer
135
- model = DavenetImageFeaturizer()
136
- patch_size = 1
137
- dim = 1024
138
- elif name == "dinov2":
139
- from denseav.featurizers.DINOv2 import DINOv2Featurizer
140
- model = DINOv2Featurizer()
141
- patch_size = 14
142
- dim = 768
143
- else:
144
- raise ValueError("unknown model: {}".format(name))
145
- return model, patch_size, dim
146
-
147
-
148
- def get_audio_featurizer(name, **kwargs):
149
- if name == "davenet":
150
- from denseav.featurizers.DAVENet import DavenetAudioFeaturizer
151
- model = DavenetAudioFeaturizer()
152
- dim = 1024
153
- elif name == "dino8":
154
- model, _, dim = get_image_featurizer("dino8")
155
- elif name == "hubert":
156
- from denseav.featurizers.Hubert import Hubert
157
- model = Hubert()
158
- dim = 1024
159
- elif name == "cavmae":
160
- from denseav.featurizers.CAVMAE import CAVMAEAudioFeaturizer
161
- model = CAVMAEAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
162
- dim = 768
163
- elif name == "imagebind":
164
- from denseav.featurizers.ImageBind import ImageBindAudioFeaturizer
165
- model = ImageBindAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
166
- dim = 1024
167
- elif name == "audiomae":
168
- from denseav.featurizers.AudioMAE import AudioMAE
169
- model = AudioMAE(kwargs["output_root"], False)
170
- dim = 768
171
- elif name == "audiomae-finetuned":
172
- from denseav.featurizers.AudioMAE import AudioMAE
173
- model = AudioMAE(kwargs["output_root"], True)
174
- dim = 768
175
- else:
176
- raise ValueError("Unknown audio model type")
177
-
178
- return model, dim
179
-
180
-
181
- def load_img(image_path, transform):
182
- return transform(Image.open(image_path)).unsqueeze(0)
183
-
184
-
185
- def pytorch_to_pil(tensor):
186
- return Image.fromarray((unnorm(tensor).permute(0, 2, 3, 1).cpu() * 255)
187
- .clamp(0, 255).to(torch.uint8).detach().numpy()[0])
188
-
189
-
190
- def _get_random_window(waveform, mask, min_size, max_size):
191
- effective_size = mask.sum().to(torch.int64)
192
- if effective_size <= min_size:
193
- return waveform, mask
194
- else:
195
- window_size = min(torch.randint(low=min_size, high=min(effective_size, max_size), size=()), waveform.shape[0])
196
- if window_size == waveform.shape[0]:
197
- window_start = 0
198
- else:
199
- window_start = torch.randint(low=0, high=effective_size - window_size, size=())
200
-
201
- new_waveform = torch.zeros_like(waveform)
202
- new_mask = torch.zeros_like(mask)
203
- new_waveform[window_start:window_start + window_size] = waveform[window_start:window_start + window_size]
204
- new_mask[window_start:window_start + window_size] = mask[window_start:window_start + window_size]
205
- return new_waveform, new_mask
206
-
207
-
208
- def _splice_clips(clip1, clip2, loc, easing_size):
209
- assert loc >= 0 and loc < len(clip1), "Invalid location"
210
- assert easing_size > 0 and easing_size <= len(clip2), "Invalid easing size"
211
-
212
- try:
213
- assert loc + clip2.shape[0] < clip1.shape[0]
214
- except Exception as e:
215
- print(loc, clip2.shape[0], clip1.shape[0])
216
- raise e
217
-
218
- # Split clip1 into three parts: before splice, easing region, after splice
219
- before_splice = clip1[:loc]
220
- after_splice = clip1[loc + clip2.shape[0]:]
221
-
222
- # Compute the fading weights for the easing region
223
- # fade_in_weights = torch.cos(torch.linspace(1, 0, easing_size, device=clip1.device))
224
- fade_in_weights = 0.5 * (1 + torch.cos(math.pi * torch.linspace(0, 1, easing_size)))
225
- fade_out_weights = 1 - fade_in_weights
226
-
227
- clip1_ease = torch.cat([
228
- fade_in_weights,
229
- torch.zeros(clip2.shape[0] - easing_size * 2),
230
- fade_out_weights,
231
- ])
232
-
233
- mask = torch.cat([torch.ones(loc), clip1_ease, torch.ones(clip1.shape[0] - (loc + clip2.shape[0]))])
234
-
235
- # Apply fading weights to clip1 and clip2 within the easing region
236
- splice = clip1_ease * clip1[loc:loc + clip2.shape[0]] + (1 - clip1_ease) * clip2
237
-
238
- # Concatenate all parts back together
239
- spliced_clip = torch.cat((before_splice, splice, after_splice))
240
-
241
- return spliced_clip, mask
242
-
243
-
244
- def _generate_random_subset(waveform, low, high):
245
- length = len(waveform)
246
-
247
- # If waveform is smaller than low or has zero length, return unmodified
248
- if length < low or length == 0:
249
- return waveform
250
-
251
- # Generate random start index within valid range
252
- start = random.randint(0, length - low)
253
-
254
- # Generate random subset size within valid range
255
- subset_size = random.randint(low, min(high, length - start))
256
-
257
- # Extract the random subset from the waveform
258
- subset = waveform[start: start + subset_size]
259
-
260
- return subset
261
-
262
-
263
- def level_audio(waveform):
264
- waveform -= waveform.mean()
265
- waveform /= waveform.abs.max().valus.clamp_min(.0001)
266
- return waveform
267
-
268
-
269
- def prep_waveform(waveform,
270
- obs_sr,
271
- target_length,
272
- spec_mel_bins,
273
- spec_mean,
274
- spec_std,
275
- sample_rate,
276
- return_spec,
277
- random_clip,
278
- extra_audio_masking,
279
- neg_waveform,
280
- neg_obs_sr,
281
- audio_level,
282
- audio_aug,
283
- ):
284
- if obs_sr != sample_rate:
285
- waveform = resample(waveform, obs_sr, sample_rate)
286
- if audio_level:
287
- waveform = level_audio(waveform)
288
-
289
- if neg_obs_sr is not None and neg_obs_sr != sample_rate:
290
- neg_waveform = resample(neg_waveform, neg_obs_sr, sample_rate)
291
- if audio_level:
292
- neg_waveform = level_audio(neg_waveform)
293
-
294
- if neg_obs_sr is not None: # and random.random() > .5:
295
- neg_waveform_clip = _generate_random_subset(neg_waveform, sample_rate, sample_rate * 4)
296
- if waveform.shape[0] - neg_waveform_clip.shape[0] > 0:
297
- start = random.randint(0, waveform.shape[0] - neg_waveform_clip.shape[0] - 1)
298
- easing = max(int(neg_waveform_clip.shape[0] * 1 / 4), sample_rate // 2)
299
- easing = min(int(neg_waveform_clip.shape[0] * 1 / 2), easing)
300
- waveform, pos_mask = _splice_clips(waveform, neg_waveform_clip, start, easing_size=easing)
301
- else:
302
- waveform, pos_mask = waveform, torch.ones_like(waveform)
303
- else:
304
- waveform, pos_mask = waveform, torch.ones_like(waveform)
305
-
306
- mask = torch.ones_like(waveform)
307
- original_length = waveform.shape[0]
308
-
309
- if target_length == 10:
310
- target_samples = 164200 # Result is 1024 after spec
311
- else:
312
- target_samples = int(target_length * sample_rate)
313
-
314
- padding = target_samples - original_length
315
-
316
- if padding > 0:
317
- p = torch.nn.ZeroPad2d((0, padding))
318
- waveform = p(waveform)
319
- mask = p(mask)
320
- pos_mask = p(pos_mask)
321
- else:
322
- if random_clip:
323
- start = torch.randint(0, waveform.shape[0] - target_samples, size=())
324
- else:
325
- start = 0
326
- end = start + target_samples
327
- waveform = waveform[start:end]
328
- mask = mask[start:end]
329
- pos_mask = pos_mask[start:end]
330
-
331
- audio_length = min(original_length, target_samples)
332
- total_length = target_samples
333
-
334
- if extra_audio_masking:
335
- min_size = sample_rate // 2
336
- max_size = total_length
337
- if original_length > min_size and random.random() > .5:
338
- waveform, mask = _get_random_window(waveform, mask, min_size, max_size)
339
-
340
- if audio_aug:
341
- import torchaudio_augmentations as AA
342
- from torchvision.transforms import RandomApply, Compose
343
-
344
- transform = Compose([
345
- RandomApply([AA.PolarityInversion()], p=0.5),
346
- RandomApply([AA.Noise(min_snr=0.001, max_snr=0.005)], p=0.2),
347
- RandomApply([AA.Gain()], p=0.2),
348
- RandomApply([AA.HighLowPass(sample_rate=sample_rate)], p=0.2),
349
- RandomApply([AA.PitchShift(n_samples=waveform.shape[-1], sample_rate=sample_rate)], p=0.2),
350
- RandomApply([AA.Reverb(sample_rate=sample_rate)], p=0.2)
351
- ])
352
- waveform = transform(waveform.unsqueeze(0)).squeeze(0)
353
-
354
- if return_spec:
355
- spectrogram = torchaudio.compliance.kaldi.fbank(
356
- waveform.unsqueeze(0) - waveform.mean(),
357
- htk_compat=True,
358
- sample_frequency=sample_rate,
359
- use_energy=False,
360
- window_type='hanning',
361
- num_mel_bins=spec_mel_bins,
362
- dither=0.0,
363
- frame_shift=10)
364
-
365
- spectrogram = ((spectrogram - spec_mean) / spec_std).unsqueeze(0)
366
- else:
367
- spectrogram = None
368
-
369
- if mask.mean() < .04:
370
- print(f"Bad entry: {mask.mean()}")
371
-
372
- return waveform, spectrogram, audio_length, total_length, original_length, mask, pos_mask
373
-
374
-
375
- class ToTargetTensor(object):
376
- def __call__(self, target):
377
- return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)
378
-
379
-
380
- def show_heatmap(ax,
381
- image,
382
- heatmap,
383
- cmap="bwr",
384
- color=False,
385
- center=False,
386
- show_negative=False,
387
- cax=None,
388
- vmax=None,
389
- vmin=None):
390
- frame = []
391
-
392
- if color:
393
- frame.append(ax.imshow(image))
394
- else:
395
- bw = np.dot(np.array(image)[..., :3] / 255, [0.2989, 0.5870, 0.1140])
396
- bw = np.ones_like(image) * np.expand_dims(bw, -1)
397
- frame.append(ax.imshow(bw))
398
-
399
- if center:
400
- heatmap -= heatmap.mean()
401
-
402
- if not show_negative:
403
- heatmap = heatmap.clamp_min(0)
404
-
405
- heatmap = F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), (image.shape[0], image.shape[1])) \
406
- .squeeze(0).squeeze(0)
407
-
408
- if vmax is None:
409
- vmax = np.abs(heatmap).max()
410
- if vmin is None:
411
- vmin = -vmax
412
-
413
- hm = ax.imshow(heatmap, alpha=.5, cmap=cmap, vmax=vmax, vmin=vmin)
414
- if cax is not None:
415
- plt.colorbar(hm, cax=cax, orientation='vertical')
416
-
417
- frame.extend([hm])
418
- return frame
419
-
420
-
421
- class TorchPCA(object):
422
-
423
- def __init__(self, n_components):
424
- self.n_components = n_components
425
-
426
- def fit(self, X):
427
- self.mean_ = X.mean(dim=0)
428
- unbiased = X - self.mean_.unsqueeze(0)
429
- U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
430
- self.components_ = V.T
431
- self.singular_values_ = S
432
- return self
433
-
434
- def transform(self, X):
435
- t0 = X - self.mean_.unsqueeze(0)
436
- projected = t0 @ self.components_.T
437
- return projected
438
-
439
-
440
- def pca(image_feats_list, dim=3, fit_pca=None):
441
- device = image_feats_list[0].device
442
-
443
- def flatten(tensor, target_size=None):
444
- if target_size is not None and fit_pca is None:
445
- F.interpolate(tensor, (target_size, target_size), mode="bilinear")
446
- B, C, H, W = tensor.shape
447
- return feats.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
448
-
449
- if len(image_feats_list) > 1 and fit_pca is None:
450
- target_size = image_feats_list[0].shape[2]
451
- else:
452
- target_size = None
453
-
454
- flattened_feats = []
455
- for feats in image_feats_list:
456
- flattened_feats.append(flatten(feats, target_size))
457
- x = torch.cat(flattened_feats, dim=0)
458
-
459
- if fit_pca is None:
460
- # fit_pca = PCA(n_components=dim, svd_solver='arpack').fit(np.nan_to_num(x.detach().numpy()))
461
- fit_pca = TorchPCA(n_components=dim).fit(x)
462
-
463
- reduced_feats = []
464
- for feats in image_feats_list:
465
- # x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
466
- x_red = fit_pca.transform(flatten(feats))
467
- x_red -= x_red.min(dim=0, keepdim=True).values
468
- x_red /= x_red.max(dim=0, keepdim=True).values
469
- B, C, H, W = feats.shape
470
- reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
471
-
472
- return reduced_feats, fit_pca
473
-
474
-
475
- def merge_col(fig, axes, col):
476
- gs = axes[0, col].get_gridspec()
477
- for ax in axes[:, col]:
478
- ax.remove()
479
- return fig.add_subplot(gs[:, col])
480
-
481
-
482
- def visualize_av_features(
483
- audio,
484
- video,
485
- feat_a,
486
- feat_v,
487
- att_a,
488
- n_frames,
489
- norm_before_pca=True,
490
- axes=None,
491
- fig=None,
492
- modify_fig=True,
493
- video_time=0,
494
- fit_pca=None
495
- ):
496
- assert (len(audio.shape) == 3) # C, F, T
497
- assert (len(video.shape) == 4) # T, C, H, W
498
- assert (len(feat_a.shape) == 2) # C, T
499
- assert (len(feat_v.shape) == 4) # T, C, H, W
500
- assert (len(att_a.shape) == 2) # F, T
501
-
502
- ac, af, at = audio.shape
503
- fac, fat = feat_a.shape
504
-
505
- if modify_fig:
506
- if axes is None:
507
- fig, axes = plt.subplots(3, 3, figsize=(5 * 3, 5))
508
- fig.tight_layout()
509
-
510
- bigax1 = merge_col(fig, axes, 0)
511
- bigax2 = merge_col(fig, axes, 1)
512
- _remove_axes(bigax1)
513
- _remove_axes(bigax2)
514
- remove_axes(axes[:, 2])
515
- else:
516
- bigax1 = fig.axes[-2]
517
- bigax2 = fig.axes[-1]
518
-
519
- frame_v = unnorm(video).permute(0, 2, 3, 1).detach().cpu()
520
- frame_v -= frame_v.min()
521
- frame_v /= frame_v.max()
522
-
523
- frame_a = audio.detach().cpu()
524
- frame_a -= frame_a.min()
525
- frame_a /= frame_a.max()
526
-
527
- if norm_before_pca:
528
- [red_feat_v], fit_pca = pca([F.normalize(feat_v, dim=1)], fit_pca=fit_pca)
529
- [red_feat_a], _ = pca([F.normalize(feat_a.unsqueeze(0).unsqueeze(-1), dim=1)], fit_pca=fit_pca)
530
- else:
531
- [red_feat_v], fit_pca = pca([feat_v], fit_pca=fit_pca)
532
- [red_feat_a], _ = pca([feat_a.unsqueeze(0).unsqueeze(-1)], fit_pca=fit_pca)
533
-
534
- red_feat_v = red_feat_v.permute(0, 2, 3, 1).detach().cpu()
535
- red_feat_a = red_feat_a.permute(0, 2, 3, 1)[0].detach().cpu()
536
-
537
- if red_feat_a.shape[0] == 1:
538
- new_height = int((frame_a.shape[0] / frame_a.shape[1]) * red_feat_a.shape[1])
539
- red_feat_a = torch.broadcast_to(
540
- red_feat_a, (new_height, red_feat_a.shape[1], red_feat_a.shape[2]))
541
- plt_att_a = torch.broadcast_to(att_a, (new_height, att_a.shape[1]))
542
- else:
543
- plt_att_a = att_a
544
-
545
- frac_signal = n_frames / fat
546
- n_at = int(at * frac_signal)
547
-
548
- return [bigax1.imshow(frame_v[video_time]),
549
- bigax2.imshow(red_feat_v[video_time]),
550
- axes[0, 2].imshow(frame_a[:, :n_at]),
551
- axes[0, 2].set_title("Spectrogram"),
552
- axes[1, 2].imshow(red_feat_a[:, :n_frames]),
553
- axes[1, 2].set_title("Audio Features"),
554
- axes[2, 2].imshow(plt_att_a[:, :n_frames], vmin=0),
555
- axes[2, 2].set_title("Audio Attention")], fig, fit_pca
556
-
557
-
558
- def create_label_tensor(labels, starts, ends, max_time, n_steps):
559
- assert isinstance(starts, torch.Tensor)
560
- assert isinstance(ends, torch.Tensor)
561
-
562
- ends[ends < 0] = max_time
563
- fps = n_steps / max_time
564
- times = (torch.arange(0, n_steps, device=labels.device, dtype=torch.float32) + .5) / fps
565
- after_start = starts.unsqueeze(1) <= times.unsqueeze(0)
566
- before_end = ends.unsqueeze(1) >= times.unsqueeze(0)
567
- # Find when you are inside of a word
568
- in_word = (after_start * before_end)
569
- # Find which word you are inside of
570
- word_to_use = in_word.to(torch.float32).argmax(0)
571
- # Get the label for that word, or mask out the label if in no word
572
- final_labels = labels[word_to_use] * in_word.any(0).reshape(-1, 1, 1)
573
- return final_labels
574
-
575
-
576
- def generate_subset(n, batch, seed=0):
577
- np.random.seed(seed)
578
- return np.random.permutation(n)[:batch]
579
-
580
-
581
- def channel_blur(t, window=5, std_dev=1):
582
- tb, tc, th, tw = t.shape
583
- x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
584
- k = torch.exp((-x ** 2 / (2 * std_dev ** 2)))
585
- k = k / k.sum()
586
- pad = window // 2
587
- t_pad = F.pad(t, [0, 0, 0, 0, pad, pad], mode="replicate")
588
- tpb, tpc, tph, tpw = t_pad.shape
589
- flattened_t = t_pad.permute(0, 2, 3, 1).reshape(tpb * tph * tpw, 1, -1)
590
- return F.conv1d(flattened_t, k.reshape(1, 1, window)).reshape(tpb, tph, tpw, tc).permute(0, 3, 1, 2)
591
-
592
-
593
- def time_blur(t, window=5, std_dev=1):
594
- tb, tc, tt = t.shape
595
- with torch.no_grad():
596
- x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
597
- k = torch.exp((-x ** 2 / (2 * std_dev ** 2)))
598
- k = k / k.sum()
599
- k = k.reshape(1, 1, window).detach()
600
- pad = window // 2
601
- t_pad = F.pad(t, [pad, pad], mode="replicate")
602
- return F.conv1d(t_pad.reshape(tb * tc, 1, -1), k).reshape(tb, tc, tt)
603
-
604
-
605
- def create_model_from_cfg(clazz, cfg, extra_args):
606
- import inspect
607
- expected_args = inspect.getfullargspec(clazz.__init__).args[1:]
608
- new_args = {k: v for k, v in {**cfg, **extra_args}.items() if k in expected_args}
609
- return clazz(**new_args)
610
-
611
-
612
- def load_trained_model(chkpt_dir, extra_args, strict=True):
613
- from train_av_alignment import LitAVAligner
614
- model = LitAVAligner.load_from_checkpoint(chkpt_dir, **extra_args, strict=strict).cuda()
615
- return model
616
-
617
-
618
- def flatten(l):
619
- return [item for sublist in l for item in sublist]
620
-
621
-
622
- def flatten_preds(preds):
623
- results = {}
624
- for k in preds[0].keys():
625
- if k == "caption_labels":
626
- continue
627
- if isinstance(preds[0][k], torch.Tensor):
628
- results[k] = torch.cat([p[k] for p in preds], dim=0)
629
- if "caption" in preds[0]:
630
- results["caption"] = flatten([p["caption"] for p in preds])
631
-
632
- if "metadata" in preds[0]:
633
- results["frame_files"] = flatten([list(p["metadata"]["frame_files"][0]) for p in preds])
634
- results["audio_file"] = flatten([list(p["metadata"]["audio_file"]) for p in preds])
635
- results["id"] = flatten([list(p["metadata"]["id"]) for p in preds])
636
- results["index"] = torch.tensor(flatten([list(p["metadata"]["index"]) for p in preds]))
637
-
638
- return results
639
-
640
-
641
- def batch(iterable, n=1):
642
- l = len(iterable)
643
- for ndx in range(0, l, n):
644
- yield iterable[ndx:min(ndx + n, l)]
645
-
646
-
647
- class GatherLayer(torch.autograd.Function):
648
- """Gather tensors from all process, supporting backward propagation."""
649
-
650
- @staticmethod
651
- def jvp(ctx: Any, *grad_inputs: Any) -> Any:
652
- pass
653
-
654
- @staticmethod
655
- def forward(ctx, inputs):
656
- ctx.save_for_backward(inputs)
657
- output = [torch.zeros_like(inputs) for _ in range(dist.get_world_size())]
658
- dist.all_gather(output, inputs)
659
- return tuple(output)
660
-
661
- @staticmethod
662
- def backward(ctx, *grads):
663
- (inputs,) = ctx.saved_tensors
664
- grad_out = torch.zeros_like(inputs)
665
- grad_out[:] = grads[dist.get_rank()]
666
- return grad_out
667
-
668
-
669
- class RollingAvg:
670
-
671
- def __init__(self, length, nonzero=False):
672
- self.length = length
673
- self.nonzero = nonzero
674
- self.metrics = defaultdict(lambda: deque(maxlen=self.length))
675
-
676
- def add(self, name, metric):
677
- if self.nonzero and metric == 0:
678
- return
679
- if isinstance(metric, torch.Tensor):
680
- metric = metric.detach()
681
-
682
- self.metrics[name].append(metric)
683
-
684
- def get(self, name):
685
- with torch.no_grad():
686
- return torch.tensor(list(self.metrics[name])).mean()
687
-
688
- def get_all(self):
689
- return {k: self.get(k) for k in self.metrics.keys()}
690
-
691
- def add_all(self, values):
692
- for k, v in values.items():
693
- self.add(k, v)
694
-
695
- def logall(self, log_func):
696
- for k in self.metrics.keys():
697
- log_func(k, self.get(k))
698
-
699
-
700
- def gaussian_kernel(k, sigma):
701
- kernel = torch.tensor([math.exp(-0.5 * (x - (k // 2)) ** 2 / sigma ** 2) for x in range(k)], dtype=torch.float32)
702
- kernel /= kernel.sum() # Normalize the kernel
703
- return kernel
704
-
705
-
706
- def blur_dim(t, window=5, std_dev=1, dim=-1):
707
- shape = t.shape
708
- n_dims = len(shape)
709
-
710
- # Create the Gaussian kernel
711
- with torch.no_grad():
712
- x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
713
- k = torch.exp(-x ** 2 / (2 * std_dev ** 2))
714
- k = k / k.sum()
715
- k = k.view(1, 1, window).detach()
716
-
717
- # Calculate padding
718
- pad = window // 2
719
-
720
- # Move the target dimension to the end
721
- permute_order = list(range(n_dims))
722
- permute_order.append(permute_order.pop(dim))
723
- t_permuted = t.permute(permute_order)
724
-
725
- # Flatten all dimensions except the last one
726
- new_shape = (-1, t_permuted.size(-1))
727
- t_flattened = t_permuted.reshape(new_shape)
728
-
729
- # Pad the tensor
730
- t_padded = F.pad(t_flattened.unsqueeze(1), (pad, pad), mode="replicate")
731
-
732
- # Apply convolution
733
- blurred = F.conv1d(t_padded, k)
734
-
735
- # Reshape back to original
736
- blurred = blurred.squeeze(1).reshape(*t_permuted.shape)
737
- blurred = blurred.permute([permute_order.index(i) for i in range(n_dims)])
738
-
739
- return blurred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/denseav/train.py DELETED
@@ -1,1213 +0,0 @@
1
- import os
2
- from collections import deque
3
- from itertools import combinations
4
- from os.path import join
5
-
6
- import hydra
7
- import numpy as np
8
- import pytorch_lightning as pl
9
- import torch
10
- import torch.distributed as dist
11
- import torch.nn.functional as F
12
- from omegaconf import DictConfig, OmegaConf
13
- from peft import get_peft_model, LoraConfig
14
- from pytorch_lightning import Trainer
15
- from pytorch_lightning import seed_everything
16
- from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
17
- from pytorch_lightning.loggers import TensorBoardLogger
18
- from pytorch_lightning.utilities import grad_norm
19
- from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, SequentialLR, LambdaLR
20
- from torchmetrics.functional.classification import binary_average_precision
21
-
22
- from huggingface_hub import PyTorchModelHubMixin
23
-
24
- from denseav.aggregators import get_aggregator
25
- from denseav.aligners import get_aligner, ProgressiveGrowing
26
- from denseav.constants import *
27
- from denseav.data.AVDatasets import AVDataModule
28
- from denseav.shared import flatten_preds, GatherLayer, \
29
- get_image_featurizer, get_audio_featurizer, RollingAvg, create_model_from_cfg
30
-
31
- torch.multiprocessing.set_sharing_strategy('file_system')
32
-
33
-
34
- def _imposter_indices_helper(true_indices: torch.Tensor, samples: torch.Tensor):
35
- mask = (true_indices == samples).to(torch.int64)
36
- n = mask.shape[0]
37
-
38
- if not mask.any():
39
- return samples
40
- else:
41
- new_samples = torch.randint(0, n, size=(n,), device=true_indices.device)
42
- comb_samples = mask * new_samples + (1 - mask) * samples
43
- return _imposter_indices_helper(true_indices, comb_samples)
44
-
45
-
46
- def imposter_indices(n, device):
47
- return _imposter_indices_helper(
48
- torch.arange(0, n, device=device),
49
- torch.randint(0, n, size=(n,), device=device))
50
-
51
-
52
- def get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type):
53
- max_t = audio_outputs.shape[-1]
54
- oh = F.one_hot(n_frames - 1, num_classes=max_t)
55
- audio_mask = 1 - torch.cumsum(oh, dim=1)
56
- audio_mask = F.pad(audio_mask, [1, 0], value=1)[:, :max_t].to(audio_outputs.dtype)
57
-
58
- full_sim = torch.einsum("bct,bchw->bthw", audio_outputs, image_outputs)
59
- expanded_am = audio_mask.unsqueeze(-1).unsqueeze(-1)
60
-
61
- if sim_type.endswith("mi"):
62
- offset = 10 * (full_sim.max() - full_sim.min())
63
- full_sim = (full_sim - ((1 - expanded_am) * offset)).max(1, keepdim=True).values
64
-
65
- if sim_type.startswith("mi"):
66
- full_sim = full_sim.max(-1, keepdim=True).values.max(-2, keepdim=True).values
67
-
68
- if sim_type.endswith("sa"):
69
- full_sim = (full_sim * (expanded_am / expanded_am.sum(1, keepdim=True).clamp_min(1))).sum(1, keepdim=True)
70
-
71
- return full_sim.mean(dim=[1, 2, 3])
72
-
73
-
74
- def sampled_margin_rank_loss(image_outputs, audio_outputs, n_frames, sim_type, margin=1.):
75
- """
76
- Computes the triplet margin ranking loss for each anchor image/caption pair
77
- The impostor image/caption is randomly sampled from the minibatch
78
- """
79
- assert (image_outputs.dim() == 4)
80
- assert (audio_outputs.dim() == 3)
81
- n = image_outputs.size(0)
82
- imp_ind_i = imposter_indices(n, image_outputs.device)
83
- imp_ind_a = imposter_indices(n, image_outputs.device)
84
- true_sim = get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type)
85
- imp_sim_i = get_sim_per_row(image_outputs[imp_ind_i], audio_outputs, n_frames, sim_type)
86
- imp_sim_a = get_sim_per_row(image_outputs, audio_outputs[imp_ind_a], n_frames[imp_ind_a], sim_type)
87
- a2i_loss = (margin + imp_sim_i - true_sim).clamp_min(0)
88
- i2a_loss = (margin + imp_sim_a - true_sim).clamp_min(0)
89
- return (a2i_loss + i2a_loss).mean() / 2
90
-
91
-
92
- class SimilarityCalibrator(torch.nn.Module):
93
-
94
- def __init__(self, cal_init, max_w=100, min_w=.01, subtract_mean=True, use_bias=False):
95
- super().__init__()
96
- self.max_w = max_w
97
- self.min_w = min_w
98
- self.w = torch.nn.Parameter(torch.tensor([cal_init]).log())
99
-
100
- self.use_bias = use_bias
101
- if self.use_bias:
102
- self.b = torch.nn.Parameter(torch.tensor([0.0]))
103
-
104
- self.subtract_mean = subtract_mean
105
-
106
- def get_w(self):
107
- return torch.exp(self.w).clamp_max(self.max_w).clamp_min(self.min_w)
108
-
109
- def forward(self, x):
110
- sims = self.get_w() * x
111
-
112
- if self.use_bias:
113
- sims = sims + self.b
114
-
115
- if self.subtract_mean:
116
- return sims - sims.mean()
117
- else:
118
- return sims
119
-
120
-
121
- class SpatialDropout(torch.nn.Module):
122
-
123
- def __init__(self, p, *args, **kwargs):
124
- super().__init__(*args, **kwargs)
125
- self.p = p
126
-
127
- def forward(self, x):
128
- b, c, h, w = x.shape
129
- dropout = torch.rand((b, 1, h, w), dtype=x.dtype, device=x.device) > self.p
130
-
131
- if self.training:
132
- return x * dropout
133
- else:
134
- return x
135
-
136
-
137
- class LitAVAligner(pl.LightningModule, PyTorchModelHubMixin, repo_url="https://github.com/mhamilton723/DenseAV", license="mit", tags=["denseav"]):
138
- def __init__(self,
139
- code_dim,
140
- image_model_type,
141
- image_model_token_type,
142
- image_aligner_type,
143
- image_pool_width,
144
- audio_model_type,
145
- audio_aligner_type,
146
- audio_pool_width,
147
- audio_lora,
148
- audio_lora_rank,
149
- image_lora,
150
- image_lora_rank,
151
- gradient_clipping,
152
- learn_audio_cls,
153
- silence_l1,
154
- silence_l2,
155
- tv_weight,
156
- nonneg_sim,
157
- nonneg_pressure,
158
- pretrain_lr,
159
- lr,
160
- lr_warmup,
161
- lr_schedule,
162
- lr_cycle_length,
163
- optimizer,
164
- gather_tensors,
165
- sim_agg_type,
166
- sim_agg_heads,
167
- sim_use_cls,
168
- disentangle_weight,
169
- norm_vectors,
170
- cal_init,
171
- cal_balance_weight,
172
- loss_type,
173
- loss_margin,
174
- mask_silence,
175
- finetune_image_model,
176
- finetune_audio_model,
177
- use_cached_embs,
178
- output_root,
179
- neg_audio,
180
- neg_audio_weight,
181
- head_agg,
182
- adaptive_clipping,
183
- specialization_weight,
184
- spatial_dropout,
185
- channel_dropout,
186
- mixup_weight,
187
- memory_buffer_size,
188
- loss_leak,
189
- ):
190
- super().__init__()
191
-
192
- self.code_dim = code_dim
193
- self.image_model_type = image_model_type
194
- self.image_model_token_type = image_model_token_type
195
- self.image_aligner_type = image_aligner_type
196
- self.image_pool_width = image_pool_width
197
- self.audio_model_type = audio_model_type
198
- self.audio_aligner_type = audio_aligner_type
199
- self.audio_pool_width = audio_pool_width
200
-
201
- self.gradient_clipping = gradient_clipping
202
- self.learn_audio_cls = learn_audio_cls
203
- self.silence_l1 = silence_l1
204
- self.silence_l2 = silence_l2
205
-
206
- self.tv_weight = tv_weight
207
- self.nonneg_sim = nonneg_sim
208
- self.nonneg_pressure = nonneg_pressure
209
- self.pretrain_lr = pretrain_lr
210
- self.lr = lr
211
- self.lr_warmup = lr_warmup
212
- self.lr_schedule = lr_schedule
213
- self.lr_cycle_length = lr_cycle_length
214
- self.optimizer = optimizer
215
- self.gather_tensors = gather_tensors
216
- self.sim_agg_type = sim_agg_type
217
- self.sim_agg_heads = sim_agg_heads
218
- self.sim_use_cls = sim_use_cls
219
- self.disentangle_weight = disentangle_weight
220
-
221
- self.norm_vectors = norm_vectors
222
- self.cal_init = cal_init
223
- self.cal_balance_weight = cal_balance_weight
224
- self.loss_type = loss_type
225
- self.loss_margin = loss_margin
226
- self.mask_silence = mask_silence
227
- self.finetune_image_model = finetune_image_model
228
- self.finetune_audio_model = finetune_audio_model
229
- self.use_cached_embs = use_cached_embs
230
- self.output_root = output_root
231
- self.audio_lora = audio_lora
232
- self.audio_lora_rank = audio_lora_rank
233
- self.image_lora = image_lora
234
- self.image_lora_rank = image_lora_rank
235
- self.neg_audio = neg_audio
236
- self.neg_audio_weight = neg_audio_weight
237
- self.head_agg = head_agg
238
-
239
- self.adaptive_clipping = adaptive_clipping
240
- self.specialization_weight = specialization_weight
241
- self.spatial_dropout = spatial_dropout
242
- self.channel_dropout = channel_dropout
243
- self.mixup_weight = mixup_weight
244
-
245
- self.memory_buffer_size = memory_buffer_size
246
- self.memory_buffer = deque(maxlen=self.memory_buffer_size)
247
- self.loss_leak = loss_leak
248
-
249
- if self.audio_model_type in {"audiomae", "audiomae-finetuned", "cavmae", "cavmae-mixed", "imagebind"}:
250
- self.audio_input = "spec"
251
- elif self.audio_model_type == "davenet":
252
- self.audio_input = "davenet_spec"
253
- elif self.audio_model_type == "fnac":
254
- self.audio_input = "fnac_spec"
255
- else:
256
- self.audio_input = "audio"
257
-
258
- extra_model_args = dict(output_root=output_root)
259
-
260
- self.image_model, _, self.image_feat_dim = get_image_featurizer(
261
- image_model_type, token_type=self.image_model_token_type, **extra_model_args)
262
-
263
- self.image_model.eval()
264
- if not self.finetune_image_model:
265
- for param in self.image_model.parameters():
266
- param.requires_grad = False
267
-
268
- if image_model_type in {"cavmae", "cavmae-mixed", "imagebind", "fnac"}:
269
- extra_model_args["model"] = self.image_model.model
270
-
271
- if use_cached_embs:
272
- _, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
273
- else:
274
- self.audio_model, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
275
-
276
- self.audio_model.eval()
277
- if not self.finetune_audio_model:
278
- for param in self.audio_model.parameters():
279
- param.requires_grad = False
280
-
281
- if self.image_lora:
282
- if self.image_model_type in {"sam", "dino8", "dinov2", "cavmae", "cavmae-mixed"}:
283
- target_modules = ["qkv"]
284
- elif self.image_model_type == "clip":
285
- target_modules = ["out_proj"]
286
- elif self.image_model_type == "imagebind":
287
- target_modules = ["out_proj", "fc1", "fc2"]
288
- else:
289
- target_modules = ["q", "k", "v"]
290
-
291
- peft_config = LoraConfig(
292
- target_modules=target_modules,
293
- inference_mode=False,
294
- r=image_lora_rank,
295
- lora_alpha=32,
296
- lora_dropout=0.1
297
- )
298
- self.image_model = get_peft_model(self.image_model, peft_config)
299
- self.image_model.print_trainable_parameters()
300
-
301
- if self.audio_lora:
302
- if self.audio_model_type == "hubert":
303
- target_modules = ["q_proj", "k_proj", "v_proj"]
304
- else:
305
- target_modules = ["q", "k", "v"]
306
-
307
- peft_config = LoraConfig(
308
- inference_mode=False,
309
- target_modules=target_modules,
310
- r=audio_lora_rank,
311
- lora_alpha=32,
312
- lora_dropout=0.1
313
- )
314
- self.audio_model = get_peft_model(self.audio_model, peft_config)
315
- self.audio_model.print_trainable_parameters()
316
-
317
- shared_aligner_args = dict(out_dim=self.code_dim)
318
-
319
- self.audio_aligner = get_aligner(
320
- self.audio_aligner_type, self.audio_feat_dim, **shared_aligner_args)
321
- self.image_aligner = get_aligner(
322
- self.image_aligner_type, self.image_feat_dim, **shared_aligner_args)
323
-
324
- if self.loss_type == "nce":
325
- self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=True, use_bias=False)
326
- else:
327
- self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=False, use_bias=True)
328
-
329
- if self.learn_audio_cls:
330
- self.audio_cls = torch.nn.Parameter(torch.randn(self.audio_feat_dim))
331
-
332
- if self.spatial_dropout > 0.0:
333
- self.spatial_dropout_layer = SpatialDropout(self.spatial_dropout)
334
-
335
- if self.channel_dropout > 0.0:
336
- self.channel_dropout_layer = torch.nn.Dropout2d(self.channel_dropout)
337
-
338
- self.sim_agg = get_aggregator(
339
- self.sim_agg_type,
340
- self.nonneg_sim,
341
- self.mask_silence,
342
- self.sim_agg_heads,
343
- self.head_agg,
344
- self.sim_use_cls,
345
- dim=self.image_feat_dim
346
- )
347
-
348
- self.hparams_logged = False
349
- self.rolling_avg = RollingAvg(50)
350
- self.grad_avg = RollingAvg(50, nonzero=True)
351
-
352
- self.save_hyperparameters()
353
-
354
- def set_full_train(self, full_train):
355
- self.full_train = full_train
356
-
357
- def prep_feats(self, feats, is_audio):
358
-
359
- if not is_audio and self.training and self.image_pool_width > 1:
360
- feats = torch.nn.AvgPool2d(self.image_pool_width)(feats)
361
-
362
- if is_audio and self.training and self.audio_pool_width > 1:
363
- feats = torch.nn.AvgPool2d((1, self.audio_pool_width))(feats)
364
-
365
- if self.norm_vectors:
366
- feats = F.normalize(feats, dim=1)
367
-
368
- return feats
369
-
370
- def on_before_optimizer_step(self, optimizer, optimizer_idx):
371
- norms = grad_norm(self, norm_type=2)
372
- avg_grads = self.grad_avg.get_all()
373
- params = {
374
- f"grad_2.0_norm/{name}": p
375
- for name, p in self.named_parameters()
376
- if p.grad is not None
377
- }
378
-
379
- if self.adaptive_clipping:
380
- for k in norms.keys():
381
- if k in params:
382
- avg_grad = max(avg_grads.get(k, norms[k]), 1e-5)
383
- if self.global_step > 10 and norms[k] > avg_grad * 5:
384
- print(f"Bad grad for {k}: {norms[k]} scaling to {avg_grad * 5}")
385
- torch.nn.utils.clip_grad_norm_(params[k], avg_grad * 5)
386
- norms[k] = avg_grad * 5
387
-
388
- if norms[k] > self.gradient_clipping:
389
- # print(f"Bad grad for {k}: {norms[k]} scaling to {self.gradient_clipping}")
390
- torch.nn.utils.clip_grad_norm_(params[k], self.gradient_clipping)
391
-
392
- # self.grad_avg.add_all(norms)
393
- # self.log_dict(norms)
394
-
395
- def interpolate_mask(self, mask, target_length, discrete):
396
- b, t = mask.shape
397
-
398
- mask = F.interpolate(mask.reshape(b, 1, 1, t), (1, target_length), mode="bilinear") \
399
- .reshape(b, target_length)
400
-
401
- if discrete:
402
- mask = mask > 0.01
403
- sums = mask.sum(1)
404
- all_zeros = torch.where(sums == 0)[0]
405
- if len(all_zeros) > 0:
406
- print("Fixing a bad mask")
407
- for entry in all_zeros:
408
- mask[entry, torch.randint(0, target_length - 1, size=())] = True
409
- else:
410
- return mask
411
- return mask
412
-
413
- def forward_audio(self, batch):
414
- if self.use_cached_embs:
415
- audio_feats = batch["audio_emb"]
416
- if "audio_cls" in batch:
417
- audio_cls = batch["audio_cls"]
418
- else:
419
- audio_cls = None
420
- else:
421
- audio = batch[self.audio_input]
422
-
423
- if self.full_train:
424
- audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
425
- else:
426
- with torch.no_grad():
427
- audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
428
-
429
- mask = batch[AUDIO_MASK] if AUDIO_MASK in batch else torch.ones_like(audio)
430
- pos_mask = batch[AUDIO_POS_MASK] if AUDIO_POS_MASK in batch else torch.ones_like(audio)
431
-
432
- if self.learn_audio_cls:
433
- assert audio_cls is None
434
- audio_cls = torch.broadcast_to(self.audio_cls.unsqueeze(0), (audio_feats.shape[0], audio_feats.shape[1]))
435
-
436
- aligned_audio_feats, aligned_audio_cls = self.audio_aligner(audio_feats, audio_cls)
437
-
438
- if self.channel_dropout > 0.0:
439
- aligned_audio_feats = self.channel_dropout_layer(aligned_audio_feats)
440
-
441
- aligned_audio_feats = self.prep_feats(aligned_audio_feats, is_audio=True)
442
- audio_mask = self.interpolate_mask(mask, aligned_audio_feats.shape[-1], True)
443
- audio_pos_mask = self.interpolate_mask(pos_mask, aligned_audio_feats.shape[-1], False)
444
-
445
- ret = {
446
- AUDIO_MASK: audio_mask,
447
- AUDIO_POS_MASK: audio_pos_mask,
448
- AUDIO_FEATS: aligned_audio_feats,
449
- }
450
-
451
- if aligned_audio_cls is not None:
452
- ret[AUDIO_CLS] = aligned_audio_cls
453
-
454
- return ret
455
-
456
- # @autocast(device_type="cuda", enabled=False)
457
- def forward_image(self, batch, max_batch_size=None):
458
-
459
- with torch.no_grad():
460
- image = batch[IMAGE_INPUT]
461
- b, nf, c, h, w = image.shape
462
- image = image.reshape(b * nf, c, h, w)
463
-
464
- if max_batch_size is None:
465
- max_batch_size = image.shape[0]
466
-
467
- chunks = [image[i:i + max_batch_size] for i in range(0, image.shape[0], max_batch_size)]
468
-
469
- all_image_feats = []
470
- all_image_cls = []
471
-
472
- for chunk in chunks:
473
- if self.full_train:
474
- image_feats, image_cls = self.image_model(chunk, include_cls=True)
475
- else:
476
- with torch.no_grad():
477
- image_feats, image_cls = self.image_model(chunk, include_cls=True)
478
-
479
- aligned_image_feats, aligned_image_cls = self.image_aligner(image_feats, image_cls)
480
-
481
- all_image_feats.append(aligned_image_feats)
482
- all_image_cls.append(aligned_image_cls)
483
-
484
- # Stitch the chunks back together
485
- aligned_image_feats = torch.cat(all_image_feats, dim=0)
486
- aligned_image_cls = torch.cat(all_image_cls, dim=0)
487
-
488
- if self.channel_dropout > 0.0:
489
- aligned_image_feats = self.channel_dropout_layer(aligned_image_feats)
490
-
491
- if self.spatial_dropout > 0.0:
492
- aligned_image_feats = self.spatial_dropout_layer(aligned_image_feats)
493
-
494
- aligned_image_feats = self.prep_feats(aligned_image_feats, is_audio=False)
495
- ret = {IMAGE_FEATS: aligned_image_feats}
496
-
497
- if IMAGE_MASK in batch:
498
- with torch.no_grad():
499
- mask = batch[IMAGE_MASK]
500
- mask = mask.reshape(b * nf, 1, h, w)
501
- b, c, h, w = aligned_image_feats.shape
502
- mask = F.adaptive_avg_pool2d(mask.to(aligned_image_feats), output_size=(h, w))
503
- ret[IMAGE_MASK] = mask
504
-
505
- if aligned_image_cls is not None:
506
- ret[IMAGE_CLS] = aligned_image_cls
507
-
508
- return ret
509
-
510
- def forward(self, batch):
511
- audio_feat_dict = self.forward_audio(batch)
512
- image_feat_dict = self.forward_image(batch)
513
- return {**image_feat_dict, **audio_feat_dict}
514
-
515
- def contrast_loss(self, sims):
516
- b = sims.shape[0]
517
- sims = sims - torch.eye(b, b, device=sims.device) * self.loss_margin
518
- sims_1 = sims
519
- sims_2 = sims.permute(1, 0)
520
-
521
- if self.loss_leak > 0.0:
522
- id = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
523
- label_mask = id * (1 - self.loss_leak)
524
- label_mask += (1 - id) * self.loss_leak / (sims_1.shape[0] - 1)
525
- label_mask /= label_mask.sum(dim=1, keepdim=True)
526
- else:
527
- label_mask = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
528
-
529
- labels = torch.arange(0, sims.shape[0], device=sims.device)
530
- self.rolling_avg.add(f"acc/1", (sims.argmax(dim=1) == labels).to(sims).mean())
531
- self.rolling_avg.add(f"acc/2", (sims.argmax(dim=0) == labels).to(sims).mean())
532
-
533
- if self.loss_type == "margin":
534
- margin_loss_tensor = (sims - torch.diag(sims)).clamp_min(0)
535
- margin_loss = margin_loss_tensor.mean()
536
- self.rolling_avg.add(f"loss/frac_nonzero", (margin_loss_tensor > 0).to(sims).mean())
537
- self.rolling_avg.add(f"loss/margin", margin_loss)
538
- return margin_loss
539
- elif self.loss_type == "ce":
540
- ce_loss = 1 / 2 * F.cross_entropy(sims_1, labels) + \
541
- 1 / 2 * F.cross_entropy(sims_2, labels)
542
- self.rolling_avg.add(f"loss/ce", ce_loss)
543
- return ce_loss
544
- elif self.loss_type == "bce":
545
- bce_loss = F.binary_cross_entropy_with_logits(sims_1.flatten(), label_mask.flatten())
546
- self.rolling_avg.add(f"loss/bce", bce_loss)
547
- return bce_loss
548
- elif self.loss_type == "nce":
549
- nce_loss = 1 / 2 * (-F.log_softmax(sims_1, dim=-1) * label_mask).sum(1).mean() + \
550
- 1 / 2 * (-F.log_softmax(sims_2, dim=-1) * label_mask).sum(1).mean()
551
- self.rolling_avg.add(f"loss/nce", nce_loss)
552
- return nce_loss
553
- else:
554
- raise ValueError(f"Unknown loss type {self.loss_type}")
555
-
556
- def loss(self, preds):
557
- image_feats = preds[IMAGE_FEATS]
558
- audio_feats = preds[AUDIO_FEATS]
559
- audio_mask = preds[AUDIO_MASK]
560
- image_mask = preds[IMAGE_MASK]
561
- audio_pos_mask = preds[AUDIO_POS_MASK]
562
- if DATA_SOURCE in preds:
563
- source = preds[DATA_SOURCE].to(torch.int64)
564
- else:
565
- source = None
566
-
567
- uncal_sims = self.sim_agg(preds, agg_heads=True)
568
- sims = self.sim_cal(uncal_sims)
569
-
570
- _mask = 1 - torch.eye(sims.shape[0], device=sims.device)
571
- self.log(f"sim/pos", torch.diag(sims).mean())
572
- self.log(f"sim/neg", (sims * _mask).sum() / (_mask.sum()))
573
- self.log(f"sim/uncal_pos", torch.diag(uncal_sims).mean())
574
- self.log(f"sim/uncal_neg", (uncal_sims * _mask).sum() / (_mask.sum()))
575
-
576
- b, c, h, w = image_feats.shape
577
- b, c, f, t = audio_feats.shape
578
- n_samples = 250
579
-
580
- nh = self.sim_agg_heads
581
- image_feats_by_head = image_feats.reshape(b, self.sim_agg_heads, c // nh, h, w)
582
- audio_feats_by_head = audio_feats.reshape(b, self.sim_agg_heads, c // nh, f, t)
583
-
584
- def maybe_clamp(t):
585
- return t.clamp_min(0) if self.nonneg_sim else t
586
-
587
- paired_sim_raw = self.sim_agg.get_pairwise_sims(preds, raw=True, agg_sim=False, agg_heads=False)
588
- paired_sim = maybe_clamp(paired_sim_raw)
589
-
590
- loss = 0.0
591
-
592
- if self.nonneg_pressure:
593
- afb, afk, afc, aff, aft = audio_feats_by_head.shape
594
- ifb, ifk, ifc, ifh, ifw = image_feats_by_head.shape
595
- assert (afb == ifb)
596
-
597
- device = audio_feats_by_head.device
598
- random_b = torch.randint(0, afb, size=(n_samples,), device=device)
599
- random_t = torch.randint(0, aft, size=(n_samples,), device=device)
600
- random_f = torch.randint(0, aff, size=(n_samples,), device=device)
601
- random_h = torch.randint(0, ifh, size=(n_samples,), device=device)
602
- random_w = torch.randint(0, ifw, size=(n_samples,), device=device)
603
-
604
- random_audio_feats = audio_feats_by_head[random_b, :, :, random_f, random_t]
605
- random_image_feats = image_feats_by_head[random_b, :, :, random_h, random_w]
606
- random_sim_raw = torch.einsum("bkc,dkc->bdk", random_audio_feats, random_image_feats)
607
-
608
- nonneg_loss = random_sim_raw.clamp_max(0).square().mean()
609
- self.rolling_avg.add(f"loss/nonneg", nonneg_loss)
610
- loss += nonneg_loss * self.nonneg_pressure
611
-
612
- if self.silence_l1 > 0 or self.silence_l2 > 0:
613
- masked_b, masked_t = torch.where(~audio_mask)
614
- if len(masked_b) > n_samples:
615
- subset = torch.randperm(len(masked_b))[:n_samples]
616
- masked_b = masked_b[subset]
617
- masked_t = masked_t[subset]
618
-
619
- if len(masked_b) == n_samples:
620
- silent_audio_feats = audio_feats_by_head[masked_b, :, :, :, masked_t].mean(-1) # d k c
621
- silence_tensor = maybe_clamp(
622
- torch.einsum("bkchw,dkc->bkdhw", image_feats_by_head, silent_audio_feats))
623
-
624
- silence_l1_loss = silence_tensor.abs().mean()
625
- self.rolling_avg.add(f"loss/silence_l1", silence_l1_loss)
626
- loss += silence_l1_loss * self.silence_l1
627
-
628
- silence_l2_loss = silence_tensor.square().mean()
629
- self.rolling_avg.add(f"loss/silence_l2", silence_l2_loss)
630
- loss += silence_l2_loss * self.silence_l2
631
- else:
632
- pass
633
-
634
- if self.neg_audio_weight > 0 and self.neg_audio:
635
- b, t = audio_pos_mask.shape
636
- negative_weight = ((1 - audio_pos_mask) * audio_mask.to(sims)).reshape(b, 1, 1, 1, 1, t)
637
- negative_weight = torch.broadcast_to(negative_weight, paired_sim.shape)
638
- if negative_weight.sum() > 0:
639
- neg_audio_loss = (paired_sim.square() * negative_weight).sum() \
640
- / negative_weight.sum().clamp_min(0.1)
641
- self.rolling_avg.add(f"loss/neg_audio", neg_audio_loss)
642
- self.rolling_avg.add(f"loss/neg_weight_avg", negative_weight.mean())
643
- loss += neg_audio_loss * self.neg_audio_weight
644
- else:
645
- print("WARNING: No negative samples found in batch")
646
-
647
- if self.tv_weight > 0:
648
- tv_loss = (paired_sim[:, :, :, :, :, 1:] - paired_sim[:, :, :, :, :, :-1]).square().mean()
649
- self.rolling_avg.add(f"loss/tv", tv_loss)
650
- loss += tv_loss * self.tv_weight
651
-
652
- self.log(f"cal/w", self.sim_cal.get_w())
653
- if self.cal_balance_weight > 0.0:
654
- cal_balance = (np.log(self.cal_init) - torch.log(self.sim_cal.get_w().clamp_min(.00000001))) \
655
- .clamp_min(0).square().mean()
656
- self.rolling_avg.add(f"loss/cal_balance", cal_balance)
657
- loss += cal_balance * self.cal_balance_weight
658
-
659
- if self.disentangle_weight > 0.0:
660
- assert source is not None
661
- assert self.sim_agg_heads % 2 == 0
662
-
663
- dilation = self.sim_agg_heads // 2
664
- sources_oh = F.one_hot(source, num_classes=2)
665
- b, h = sources_oh.shape
666
- sources_mask = 1 - torch.broadcast_to(sources_oh.unsqueeze(-1), (b, h, dilation)) \
667
- .reshape(b, h * dilation).to(paired_sim)
668
- disentangle_loss = torch.einsum("bkhwft,bk->bhwft", paired_sim, sources_mask).square().mean()
669
- self.rolling_avg.add(f"loss/disentangle", disentangle_loss)
670
- loss += disentangle_loss * self.disentangle_weight
671
-
672
- if self.specialization_weight > 0.0 and self.sim_agg_heads > 1:
673
- total_specialization_loss = 0.0
674
- combos = list(combinations(range(self.sim_agg_heads), 2))
675
- for i, j in combos:
676
- specialization_loss_pair = (paired_sim[:, i].abs() * paired_sim[:, j].abs()).mean()
677
- total_specialization_loss += specialization_loss_pair
678
- avg_specialization_loss = total_specialization_loss / len(combos)
679
- self.rolling_avg.add(f"loss/specialize", avg_specialization_loss)
680
- loss += avg_specialization_loss * self.specialization_weight
681
-
682
- if self.mixup_weight > 0.0:
683
- b, _, h, w = image_mask.shape
684
- neg_img_mask = torch.broadcast_to(
685
- 1 - image_mask.to(paired_sim).reshape(b, 1, h, w, 1, 1),
686
- paired_sim.shape)
687
- image_mixup_loss = (paired_sim * neg_img_mask).square().sum() / neg_img_mask.sum().clamp_min(0.1)
688
- self.rolling_avg.add(f"loss/image_mixup", image_mixup_loss)
689
- loss += image_mixup_loss * self.mixup_weight
690
-
691
- sims = sims
692
- loss += self.contrast_loss(sims)
693
- self.rolling_avg.add(f"loss/total", loss)
694
-
695
- return loss
696
-
697
- def setup_hparams(self):
698
- recalls = ['A_r1', 'A_r5', 'A_r10', 'I_r1', 'I_r5', 'I_r10']
699
-
700
- if self.trainer.datamodule.use_extra_val_sets:
701
- datasets = ["Places", "AudioSet"]
702
- else:
703
- datasets = ["Val"]
704
-
705
- heads = ["total"]
706
-
707
- metric_names = [
708
- "hp/speech_basic_ap", "hp/speech_advanced_ap", "hp/sound_basic_ap",
709
- "hp/speech_basic_iou", "hp/speech_advanced_iou", "hp/sound_basic_iou",
710
- ]
711
- for dataset in datasets:
712
- for head in heads:
713
- for recall in recalls:
714
- metric_names.append(f"hp/{dataset}/{head}/{recall}")
715
-
716
- if self.sim_agg_heads == 2:
717
- metric_names.extend(["hp/ap_dis", "hp/act_dis"])
718
-
719
- if hasattr(self.trainer, "datamodule"):
720
- all_hparams = {**self.hparams, **self.trainer.datamodule.hparams}
721
- else:
722
- all_hparams = self.hparams
723
-
724
- starting_values = {n: torch.nan for n in metric_names}
725
- self.logger.log_hyperparams(all_hparams, starting_values)
726
-
727
- def on_train_start(self):
728
- self.setup_hparams()
729
- self.hparams_logged = True
730
-
731
- def on_train_batch_start(self, batch, batch_idx):
732
- remake_optimizers = False
733
-
734
- if isinstance(self.image_aligner, ProgressiveGrowing):
735
- should_remake = self.image_aligner.maybe_change_phase(self.global_step)
736
- remake_optimizers = remake_optimizers or should_remake
737
- if isinstance(self.audio_aligner, ProgressiveGrowing):
738
- should_remake = self.audio_aligner.maybe_change_phase(self.global_step)
739
- remake_optimizers = remake_optimizers or should_remake
740
-
741
- if remake_optimizers:
742
- raise NotImplementedError()
743
-
744
- def _combine_preds(self, all_preds):
745
- temp = {}
746
- new_preds = {}
747
-
748
- # Collect tensors for each key into lists
749
- for d in all_preds:
750
- for key, value in d.items():
751
- if isinstance(value, torch.Tensor):
752
- if key not in temp:
753
- temp[key] = []
754
- temp[key].append(value)
755
-
756
- # Concatenate all tensors for each key using a single call to torch.cat
757
- for key, tensor_list in temp.items():
758
- new_preds[key] = torch.cat(tensor_list)
759
- return new_preds
760
-
761
- def training_step(self, batch, batch_idx):
762
- assert batch[IMAGE_INPUT].shape[1] == 1
763
-
764
- preds = self.forward(batch)
765
- if DATA_SOURCE in batch:
766
- preds[DATA_SOURCE] = batch[DATA_SOURCE]
767
-
768
- if self.trainer.world_size > 1 and self.gather_tensors:
769
- for k, v in preds.items():
770
- new_v = v.contiguous()
771
- preds[k] = torch.cat(GatherLayer.apply(new_v), dim=0)
772
-
773
- if self.memory_buffer_size > 0:
774
- new_preds = self._combine_preds(list(self.memory_buffer) + [preds])
775
- else:
776
- new_preds = preds
777
-
778
- loss = self.loss(new_preds)
779
-
780
- if self.memory_buffer_size > 0:
781
- self.memory_buffer.append(self._recursive_detach(preds, gather=False))
782
-
783
- if self.trainer.is_global_zero and self.global_step % 50 == 1:
784
- writer = self.logger.experiment
785
- self.rolling_avg.logall(lambda k, v: writer.add_scalar(k, v, global_step=self.global_step))
786
-
787
- if self.trainer.scaler is not None:
788
- self.log("loss_scale", self.trainer.scaler.get_scale())
789
-
790
- if self.global_step % 10000 == 0 and self.global_step > 0:
791
- print("RESETTING TFEVENT FILE")
792
- self.logger.experiment.close()
793
- self.logger.experiment._get_file_writer()
794
-
795
- return loss
796
-
797
- def on_validation_start(self) -> None:
798
- if not self.hparams_logged:
799
- self.setup_hparams()
800
- self.hparams_logged = True
801
-
802
- def _auto_gather(self, t):
803
- if t.dtype == torch.bool:
804
- t = t.to(torch.float)
805
-
806
- if self.trainer.num_devices == 1:
807
- return t.cpu()
808
-
809
- t = torch.clone(t).contiguous()
810
- if self.trainer.is_global_zero:
811
- gather_list = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
812
- dist.gather(t, gather_list)
813
- return torch.cat(gather_list, dim=0).cpu()
814
- else:
815
- dist.gather(t)
816
-
817
- def validation_step(self, batch, batch_idx, dataloader_idx=0):
818
-
819
- with torch.no_grad():
820
- preds = self.forward(batch)
821
-
822
- ret = {}
823
- for k in preds.keys():
824
- if k in preds:
825
- ret[k] = self._auto_gather(preds[k])
826
-
827
- batch_keys = [IMAGE_INPUT, "spec", "semseg", "num_pixels_per_class", 'total_length']
828
- for k in batch_keys:
829
- if k in batch:
830
- ret[k] = self._auto_gather(batch[k])
831
-
832
- if "metadata" in batch:
833
- if isinstance(batch["metadata"]["id"], torch.Tensor):
834
- ret["id"] = self._auto_gather(batch["metadata"]["id"])
835
- ret["index"] = self._auto_gather(batch["metadata"]["index"])
836
-
837
- return ret
838
-
839
- def _calc_recalls(self, sim):
840
- top_10_a = sim.topk(10, 0).indices == torch.arange(sim.shape[0]).unsqueeze(0)
841
- top_10_i = (sim.topk(10, 1).indices == torch.arange(sim.shape[0]).unsqueeze(1)).permute(1, 0)
842
- a_recall = lambda p: top_10_a[0:p].any(0).to(sim).mean()
843
- i_recall = lambda p: top_10_i[0:p].any(0).to(sim).mean()
844
- return {'A_r1': a_recall(1),
845
- 'A_r5': a_recall(5),
846
- 'A_r10': a_recall(10),
847
- 'I_r1': i_recall(1),
848
- 'I_r5': i_recall(5),
849
- 'I_r10': i_recall(10)}
850
-
851
- def calc_recalls(self, preds, dataset):
852
- sim = self.sim_agg.forward_batched(
853
- preds=preds,
854
- agg_heads=False,
855
- batch_size=4,
856
- ).cpu()
857
-
858
- all_metrics = dict()
859
- for k, v in self._calc_recalls(sim.sum(-1)).items():
860
- all_metrics[f"hp/{dataset}/total/" + k] = v
861
-
862
- return all_metrics
863
-
864
- def retrieval_validation(self, outputs, dataset_name):
865
- if len(outputs) == 0:
866
- return
867
-
868
- if self.trainer.is_global_zero:
869
- results = flatten_preds(outputs)
870
- if not self.trainer.sanity_checking:
871
- print(results[IMAGE_FEATS].shape[0])
872
- # assert (results[IMAGE_FEATS].shape[0] == 1000)
873
- results[IMAGE_FEATS] = results[IMAGE_FEATS].cpu()
874
- results[AUDIO_FEATS] = results[AUDIO_FEATS].cuda()
875
- if self.sim_use_cls:
876
- results[AUDIO_CLS] = results[AUDIO_CLS].cuda()
877
- results[AUDIO_CLS] = results[AUDIO_CLS].cuda()
878
-
879
- results[AUDIO_MASK] = results[AUDIO_MASK].cuda()
880
-
881
- recalls = self.calc_recalls(results, dataset_name)
882
-
883
- results[IMAGE_FEATS] = results[IMAGE_FEATS].cuda()
884
-
885
- writer = self.logger.experiment
886
- print("here")
887
- for name, v in recalls.items():
888
- writer.add_scalar(f"{name}", v, self.global_step + 1)
889
-
890
- def semseg_validation(self, speech_preds, sound_preds):
891
-
892
- if self.trainer.is_global_zero:
893
- from eval_utils import get_paired_heatmaps
894
- def prep_preds(preds, loader):
895
- results = flatten_preds(preds)
896
- metadata = loader.dataset.metadata
897
- ordered_metadata = metadata.iloc[results["index"].numpy(), :].copy()
898
- ordered_metadata["order"] = range(len(ordered_metadata))
899
- return results, ordered_metadata
900
-
901
- [_, _, speech_loader, sound_loader] = self.trainer.val_dataloaders
902
- speech_results, speech_metadata = prep_preds(speech_preds, speech_loader)
903
- sound_results, sound_metadata = prep_preds(sound_preds, sound_loader)
904
-
905
- self.sound_metrics, unique_sound_indices = get_paired_heatmaps(
906
- self, sound_results, sound_metadata["ade_class_id"], None)
907
-
908
- self.speech_metrics, unique_word_indices = get_paired_heatmaps(
909
- self, speech_results, speech_metadata["ade_class_id"], speech_metadata["timing"])
910
-
911
- writer = self.logger.experiment
912
-
913
- all_metrics = {
914
- **{"sound_" + k: v for k, v in self.sound_metrics.items()},
915
- **{"speech_" + k: v for k, v in self.speech_metrics.items()},
916
- }
917
-
918
- for k, v in all_metrics.items():
919
- writer.add_scalar(f"hp/{k}", torch.tensor(v).mean(), self.global_step + 1)
920
-
921
- def disentangle_validation(self, word_preds, sound_preds):
922
-
923
- if len(word_preds) == 0 or len(sound_preds) == 0:
924
- return
925
-
926
- if self.trainer.is_global_zero:
927
- word_preds = flatten_preds(word_preds)
928
- sound_preds = flatten_preds(sound_preds)
929
-
930
- word_scores = self.sim_agg.get_pairwise_sims(
931
- word_preds,
932
- raw=False,
933
- agg_sim=True,
934
- agg_heads=False,
935
- )
936
-
937
- sound_scores = self.sim_agg.get_pairwise_sims(
938
- sound_preds,
939
- raw=False,
940
- agg_sim=True,
941
- agg_heads=False,
942
- )
943
-
944
- all_scores = torch.cat([word_scores, sound_scores], dim=0)
945
- all_scores -= all_scores.min(dim=0, keepdim=True).values
946
- all_scores /= all_scores.max(dim=0, keepdim=True).values.clamp_min(.0001)
947
-
948
- is_words = torch.cat([
949
- torch.ones(word_scores.shape[0]),
950
- torch.zeros(sound_scores.shape[0])], dim=0).to(torch.bool)
951
-
952
- assert all_scores.shape[1] == 2
953
- ap_matrix = torch.zeros(2, 2)
954
- act_matrix = torch.zeros(2, 2)
955
-
956
- for head in range(2):
957
- # writer.add_histogram(f"h{head}_all_scores", all_scores[:, head])
958
- for dataset_num in range(2):
959
- if dataset_num == 0:
960
- labels = is_words
961
- else:
962
- labels = ~is_words
963
-
964
- ap_matrix[head, dataset_num] = binary_average_precision(
965
- all_scores[:, head].cpu(), labels.to(torch.int64).cpu())
966
-
967
- act_matrix[head, dataset_num] = 1 - (all_scores[:, head][labels]).mean()
968
-
969
- ap_dis = max(.5 * (ap_matrix[0, 0] + ap_matrix[1, 1]),
970
- .5 * (ap_matrix[0, 1] + ap_matrix[1, 0]))
971
-
972
- act_dis = max(.5 * (act_matrix[0, 0] + act_matrix[1, 1]),
973
- .5 * (act_matrix[0, 1] + act_matrix[1, 0]))
974
-
975
- print("AP", ap_matrix)
976
- print("AP dis", ap_dis)
977
- print("Act", act_matrix)
978
- print("Act dis", act_dis)
979
-
980
- writer = self.logger.experiment
981
- writer.add_scalar("hp/ap_dis", ap_dis, self.global_step + 1)
982
- writer.add_scalar("hp/act_dis", act_dis, self.global_step + 1)
983
-
984
- def validation_epoch_end(self, outputs) -> None:
985
- print("Val end")
986
- with torch.no_grad():
987
- if self.trainer.datamodule.use_extra_val_sets:
988
- if self.sim_agg_heads == 2:
989
- self.disentangle_validation(outputs[0], outputs[1])
990
- self.retrieval_validation(outputs[0], "Places")
991
- self.retrieval_validation(outputs[1], "AudioSet")
992
- self.semseg_validation(outputs[2], outputs[3])
993
-
994
- else:
995
- print("HERE!")
996
- self.retrieval_validation(outputs, "Val")
997
-
998
- writer = self.logger.experiment
999
- writer.flush()
1000
-
1001
- def _recursive_detach(self, obj, gather=True):
1002
- if isinstance(obj, torch.Tensor):
1003
- if gather:
1004
- return self._auto_gather(obj)
1005
- else:
1006
- obj.detach()
1007
- elif isinstance(obj, dict):
1008
- return {k: self._recursive_detach(v, gather) for k, v in obj.items()}
1009
- elif isinstance(obj, list):
1010
- return [self._recursive_detach(v, gather) for v in obj]
1011
- else:
1012
- return obj
1013
-
1014
- def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
1015
- with torch.no_grad():
1016
- predictions = {}
1017
- for k, v in batch.items():
1018
- predictions[k] = self._recursive_detach(v)
1019
- for k, v in self.forward(batch).items():
1020
- predictions[k] = self._auto_gather(v)
1021
-
1022
- return predictions
1023
-
1024
- def _configure_optimizers(self, full_train, lr):
1025
- params = [
1026
- *self.audio_aligner.parameters(),
1027
- *self.image_aligner.parameters(),
1028
- *self.sim_cal.parameters(),
1029
- *self.sim_agg.parameters()
1030
- ]
1031
-
1032
- if (self.finetune_image_model or self.image_lora) and full_train:
1033
- params.extend(self.image_model.parameters())
1034
-
1035
- if (self.finetune_audio_model or self.audio_lora) and full_train:
1036
- params.extend(self.audio_model.parameters())
1037
-
1038
- if self.learn_audio_cls:
1039
- params.append(self.audio_cls)
1040
-
1041
- last_epoch = self.global_step - 1
1042
- if self.optimizer == "adam":
1043
- opt = torch.optim.Adam(params, lr=lr, eps=1e-7)
1044
- elif self.optimizer == "nadam":
1045
- opt = torch.optim.NAdam(params, lr=lr, eps=1e-7)
1046
- else:
1047
- raise ValueError(f"Unknown optimizer {self.optimizer}")
1048
-
1049
- if self.lr_schedule == "sgdr":
1050
- scheduler = CosineAnnealingWarmRestarts(
1051
- opt, self.lr_cycle_length, 2, eta_min=lr * 2e-2, last_epoch=last_epoch)
1052
- else:
1053
- scheduler = LambdaLR(opt, lr_lambda=lambda step: 1.0, last_epoch=last_epoch)
1054
-
1055
- if self.lr_warmup > 0:
1056
- warmup = LambdaLR(
1057
- opt,
1058
- lr_lambda=lambda step: min(max(float(step), 0.0) / self.lr_warmup, 1.0),
1059
- last_epoch=last_epoch,
1060
- )
1061
- scheduler = SequentialLR(
1062
- opt,
1063
- schedulers=[warmup, scheduler],
1064
- milestones=[self.lr_warmup],
1065
- last_epoch=last_epoch)
1066
-
1067
- scheduler = {"scheduler": scheduler, "interval": "step"}
1068
-
1069
- return [opt], [scheduler]
1070
-
1071
- def configure_optimizers(self):
1072
- if self.full_train:
1073
- return self._configure_optimizers(self.full_train, self.lr)
1074
- else:
1075
- return self._configure_optimizers(self.full_train, self.pretrain_lr)
1076
-
1077
-
1078
- @hydra.main(config_path="configs", config_name="av_align.yaml", version_base=None)
1079
- def my_app(cfg: DictConfig) -> None:
1080
- print(OmegaConf.to_yaml(cfg))
1081
- seed_everything(cfg.seed, workers=True)
1082
-
1083
- exp_name = f"{cfg.resume_prefix}"
1084
-
1085
- if cfg.image_model_type == "dino8":
1086
- patch_size = 8 * cfg.image_pool_width
1087
- elif cfg.image_model_type == "cavmae":
1088
- patch_size = 16 * cfg.image_pool_width
1089
- elif cfg.image_model_type == "imagebind":
1090
- patch_size = 16 * cfg.image_pool_width
1091
- elif cfg.image_model_type == "clip":
1092
- patch_size = 16 * cfg.image_pool_width
1093
- elif cfg.image_model_type == "cavmae-mixed":
1094
- patch_size = 16 * cfg.image_pool_width
1095
- elif cfg.image_model_type == "dinov2":
1096
- patch_size = 14 * cfg.image_pool_width
1097
- else:
1098
- raise ValueError(f"Unknown patch size for model {cfg.image_model_type}")
1099
-
1100
- datamodule = AVDataModule(
1101
- dataset_name=cfg.dataset_name,
1102
- load_size=cfg.load_size,
1103
- image_aug=cfg.image_aug,
1104
- audio_aug=cfg.audio_aug,
1105
- extra_audio_masking=cfg.extra_audio_masking,
1106
- audio_model_type=cfg.audio_model_type,
1107
- pytorch_data_dir=cfg.pytorch_data_dir,
1108
- use_cached_embs=cfg.use_cached_embs,
1109
- batch_size=cfg.batch_size,
1110
- num_workers=cfg.num_workers,
1111
- audio_level=cfg.audio_level,
1112
- neg_audio=cfg.neg_audio,
1113
- use_original_val_set=not cfg.use_extra_val_sets,
1114
- use_extra_val_sets=cfg.use_extra_val_sets,
1115
- data_for_plotting=False,
1116
- quad_mixup=cfg.quad_mixup,
1117
- bg_mixup=cfg.bg_mixup,
1118
- patch_mixup=cfg.patch_mixup,
1119
- patch_size=patch_size
1120
- )
1121
- datamodule.maybe_unpack(remove_source=cfg.submitting_to_aml)
1122
-
1123
- aligner = create_model_from_cfg(LitAVAligner, cfg, {})
1124
-
1125
- if cfg.starting_weights is not None:
1126
- loaded = torch.load(join(cfg.output_root, cfg.starting_weights), map_location='cpu')
1127
- state = loaded["state_dict"]
1128
- aligner.load_state_dict(state, strict=cfg.load_strict)
1129
- del state
1130
- del loaded
1131
-
1132
- if cfg.num_gpus > 1:
1133
- # strategy = "ddp_sharded" # _find_unused_parameters_true"
1134
- strategy = "ddp" # _find_unused_parameters_true"
1135
- else:
1136
- strategy = "auto"
1137
-
1138
- if cfg.dataset_name in {"places-audio", "mixed", "audio-set", "mixed-full"}:
1139
- val_args = dict(check_val_every_n_epoch=2)
1140
- elif cfg.dataset_name in {"dolphin"}:
1141
- val_args = dict(check_val_every_n_epoch=5)
1142
- else:
1143
- val_args = dict(val_check_interval=10000)
1144
-
1145
- # val_args = dict(val_check_interval=1000)
1146
-
1147
- def maybe_get_ckpt(ckpt_dir):
1148
- if cfg.auto_resume and os.path.exists(ckpt_dir):
1149
- print(f"Attempting to resume from {ckpt_dir}")
1150
- candidates = os.listdir(ckpt_dir)
1151
- assert (len(candidates) == 1)
1152
- return join(ckpt_dir, candidates[0])
1153
- elif cfg.auto_resume:
1154
- print(f"Could not find checkpoint at {ckpt_dir}")
1155
- return None
1156
- else:
1157
- return None
1158
-
1159
- log_dir = join(cfg.output_root, "logs", cfg.grouping_name, exp_name)
1160
- ckpt_dir = join(cfg.output_root, "checkpoints", cfg.grouping_name, exp_name)
1161
-
1162
- import gc
1163
- torch.cuda.empty_cache()
1164
- gc.collect()
1165
-
1166
- def run_exp(aligner, full_train):
1167
- trainer_args = dict(
1168
- accelerator='gpu',
1169
- strategy=strategy,
1170
- devices=cfg.num_gpus,
1171
- num_sanity_val_steps=cfg.num_sanity_val_steps,
1172
- log_every_n_steps=50,
1173
- reload_dataloaders_every_n_epochs=10,
1174
- precision="16",
1175
- # profiler="simple",
1176
- # precision="bf16",
1177
- max_steps=cfg.max_steps,
1178
- **val_args)
1179
-
1180
- aligner.set_full_train(full_train)
1181
- if full_train:
1182
- suffix = "train"
1183
- else:
1184
- suffix = "pretrain"
1185
- trainer_args["max_steps"] = cfg.pretrain_steps
1186
-
1187
- print(f"Starting {suffix} phase")
1188
-
1189
- logger = TensorBoardLogger(join(log_dir, suffix), default_hp_metric=False)
1190
- callbacks = [
1191
- ModelCheckpoint(join(ckpt_dir, suffix), every_n_epochs=1),
1192
- LearningRateMonitor(logging_interval='step'),
1193
- ]
1194
- Trainer(logger=logger,
1195
- callbacks=callbacks,
1196
- **trainer_args).fit(
1197
- aligner,
1198
- datamodule=datamodule,
1199
- ckpt_path=maybe_get_ckpt(join(ckpt_dir, suffix)))
1200
-
1201
- train_chkpt = maybe_get_ckpt(join(ckpt_dir, "train"))
1202
-
1203
- gc.collect()
1204
- if torch.cuda.is_available():
1205
- torch.cuda.empty_cache()
1206
-
1207
- if cfg.pretrain_steps > 0 and train_chkpt is None:
1208
- run_exp(aligner, full_train=False)
1209
- run_exp(aligner, full_train=True)
1210
-
1211
-
1212
- if __name__ == "__main__":
1213
- my_app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/gradio_app.py DELETED
@@ -1,196 +0,0 @@
1
- import csv
2
- import os
3
- import tempfile
4
-
5
- import gradio as gr
6
- import requests
7
- import torch
8
- import torchvision
9
- import torchvision.transforms as T
10
- from PIL import Image
11
- from featup.util import norm
12
- from torchaudio.functional import resample
13
-
14
- from denseav.train import LitAVAligner
15
- from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video
16
- from denseav.shared import norm, crop_to_divisor, blur_dim
17
- from os.path import join
18
-
19
- if __name__ == "__main__":
20
-
21
- mode = "local"
22
-
23
- if mode == "local":
24
- sample_videos_dir = "samples"
25
- else:
26
- os.environ['TORCH_HOME'] = '/tmp/.cache'
27
- os.environ['HF_HOME'] = '/tmp/.cache'
28
- os.environ['HF_DATASETS_CACHE'] = '/tmp/.cache'
29
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache'
30
- os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
31
- sample_videos_dir = "/tmp/samples"
32
-
33
-
34
- def download_video(url, save_path):
35
- response = requests.get(url)
36
- with open(save_path, 'wb') as file:
37
- file.write(response.content)
38
-
39
-
40
- base_url = "https://marhamilresearch4.blob.core.windows.net/denseav-public/samples/"
41
- sample_videos_urls = {
42
- "puppies.mp4": base_url + "puppies.mp4",
43
- "peppers.mp4": base_url + "peppers.mp4",
44
- "boat.mp4": base_url + "boat.mp4",
45
- "elephant2.mp4": base_url + "elephant2.mp4",
46
-
47
- }
48
-
49
- # Ensure the directory for sample videos exists
50
- os.makedirs(sample_videos_dir, exist_ok=True)
51
-
52
- # Download each sample video
53
- for filename, url in sample_videos_urls.items():
54
- save_path = os.path.join(sample_videos_dir, filename)
55
- # Download the video if it doesn't already exist
56
- if not os.path.exists(save_path):
57
- print(f"Downloading {filename}...")
58
- download_video(url, save_path)
59
- else:
60
- print(f"{filename} already exists. Skipping download.")
61
-
62
- csv.field_size_limit(100000000)
63
- options = ['language', "sound-language", "sound"]
64
- load_size = 224
65
- plot_size = 224
66
-
67
- video_input = gr.Video(label="Choose a video to featurize", height=480)
68
- model_option = gr.Radio(options, value="language", label='Choose a model')
69
-
70
- video_output1 = gr.Video(label="Audio Video Attention", height=480)
71
- video_output2 = gr.Video(label="Multi-Head Audio Video Attention (Only Availible for sound_and_language)",
72
- height=480)
73
- video_output3 = gr.Video(label="Visual Features", height=480)
74
-
75
- models = {o: LitAVAligner.from_pretrained(f"mhamilton723/DenseAV-{o}") for o in options}
76
-
77
-
78
- def process_video(video, model_option):
79
- model = models[model_option].cuda()
80
-
81
- original_frames, audio, info = torchvision.io.read_video(video, end_pts=10, pts_unit='sec')
82
- sample_rate = 16000
83
-
84
- if info["audio_fps"] != sample_rate:
85
- audio = resample(audio, info["audio_fps"], sample_rate)
86
- audio = audio[0].unsqueeze(0)
87
-
88
- img_transform = T.Compose([
89
- T.Resize(load_size, Image.BILINEAR),
90
- lambda x: crop_to_divisor(x, 8),
91
- lambda x: x.to(torch.float32) / 255,
92
- norm])
93
-
94
- frames = torch.cat([img_transform(f.permute(2, 0, 1)).unsqueeze(0) for f in original_frames], axis=0)
95
-
96
- plotting_img_transform = T.Compose([
97
- T.Resize(plot_size, Image.BILINEAR),
98
- lambda x: crop_to_divisor(x, 8),
99
- lambda x: x.to(torch.float32) / 255])
100
-
101
- frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2))
102
-
103
- with torch.no_grad():
104
- audio_feats = model.forward_audio({"audio": audio.cuda()})
105
- audio_feats = {k: v.cpu() for k, v in audio_feats.items()}
106
- image_feats = model.forward_image({"frames": frames.unsqueeze(0).cuda()}, max_batch_size=2)
107
- image_feats = {k: v.cpu() for k, v in image_feats.items()}
108
-
109
- sim_by_head = model.sim_agg.get_pairwise_sims(
110
- {**image_feats, **audio_feats},
111
- raw=False,
112
- agg_sim=False,
113
- agg_heads=False
114
- ).mean(dim=-2).cpu()
115
-
116
- sim_by_head = blur_dim(sim_by_head, window=3, dim=-1)
117
- print(sim_by_head.shape)
118
-
119
- temp_video_path_1 = tempfile.mktemp(suffix='.mp4')
120
-
121
- plot_attention_video(
122
- sim_by_head,
123
- frames_to_plot,
124
- audio,
125
- info["video_fps"],
126
- sample_rate,
127
- temp_video_path_1)
128
-
129
- if model_option == "sound_and_language":
130
- temp_video_path_2 = tempfile.mktemp(suffix='.mp4')
131
-
132
- plot_2head_attention_video(
133
- sim_by_head,
134
- frames_to_plot,
135
- audio,
136
- info["video_fps"],
137
- sample_rate,
138
- temp_video_path_2)
139
-
140
- else:
141
- temp_video_path_2 = None
142
-
143
- temp_video_path_3 = tempfile.mktemp(suffix='.mp4')
144
- temp_video_path_4 = tempfile.mktemp(suffix='.mp4')
145
-
146
- plot_feature_video(
147
- image_feats["image_feats"].cpu(),
148
- audio_feats['audio_feats'].cpu(),
149
- frames_to_plot,
150
- audio,
151
- info["video_fps"],
152
- sample_rate,
153
- temp_video_path_3,
154
- temp_video_path_4,
155
- )
156
- # return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4
157
-
158
- return temp_video_path_1, temp_video_path_2, temp_video_path_3
159
-
160
-
161
- with gr.Blocks() as demo:
162
- with gr.Column():
163
- gr.Markdown("## Visualizing Sound and Language with DenseAV")
164
- gr.Markdown(
165
- "This demo allows you to explore the inner attention maps of DenseAV's dense multi-head contrastive operator.")
166
- with gr.Row():
167
- with gr.Column(scale=1):
168
- model_option.render()
169
- with gr.Column(scale=3):
170
- video_input.render()
171
- with gr.Row():
172
- submit_button = gr.Button("Submit")
173
- with gr.Row():
174
- gr.Examples(
175
- examples=[
176
- [join(sample_videos_dir, "puppies.mp4"), "sound_and_language"],
177
- [join(sample_videos_dir, "peppers.mp4"), "language"],
178
- [join(sample_videos_dir, "elephant2.mp4"), "language"],
179
- [join(sample_videos_dir, "boat.mp4"), "language"]
180
-
181
- ],
182
- inputs=[video_input, model_option]
183
- )
184
- with gr.Row():
185
- video_output1.render()
186
- video_output2.render()
187
- video_output3.render()
188
-
189
- submit_button.click(fn=process_video, inputs=[video_input, model_option],
190
- outputs=[video_output1, video_output2, video_output3])
191
-
192
-
193
- if mode == "local":
194
- demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
195
- else:
196
- demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/hubconf.py DELETED
@@ -1,25 +0,0 @@
1
- # hubconf.py
2
- from denseav.train import LitAVAligner
3
-
4
- dependencies = ['torch', 'torchvision', 'PIL', 'denseav'] # List any dependencies here
5
-
6
-
7
- def _load_base(model_name):
8
- model = LitAVAligner.load_from_checkpoint(
9
- f"https://marhamilresearch4.blob.core.windows.net/denseav-public/hub/{model_name}.ckpt",
10
- **{'loss_leak': 0.0, 'use_cached_embs': False},
11
- strict=True)
12
- model.set_full_train(True)
13
- return model
14
-
15
-
16
- def sound_and_language():
17
- return _load_base("denseav_2head")
18
-
19
-
20
- def language():
21
- return _load_base("denseav_language")
22
-
23
-
24
- def sound():
25
- return _load_base("denseav_sound")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
DenseAV/samples/puppies.mp4 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d4bc5049010142b9a4364afea7da15d4e9736d95cfc9a365c2658c69ba409d56
3
- size 7534432
 
 
 
 
DenseAV/setup.py DELETED
@@ -1,37 +0,0 @@
1
- from setuptools import setup, find_packages
2
-
3
- setup(
4
- name='denseav',
5
- version='0.1.0',
6
- packages=find_packages(),
7
- install_requires=[
8
- 'torch',
9
- 'kornia',
10
- 'omegaconf',
11
- 'pytorch-lightning',
12
- 'torchvision',
13
- 'tqdm',
14
- 'torchmetrics',
15
- 'scikit-learn',
16
- 'numpy',
17
- 'matplotlib',
18
- 'timm==0.4.12',
19
- 'moviepy',
20
- 'hydra-core',
21
- 'peft==0.5.0',
22
- 'av',
23
- 'audioread'
24
- ],
25
- author='Mark Hamilton',
26
- author_email='[email protected]',
27
- description='Offical code for the CVPR 2024 Paper: Separating the "Chirp" from the "Chat": Self-supervised Visual Grounding of Sound and Language',
28
- long_description=open('README.md').read(),
29
- long_description_content_type='text/markdown',
30
- url='https://github.com/mhamilton723/DenseAV',
31
- classifiers=[
32
- 'Programming Language :: Python :: 3',
33
- 'License :: OSI Approved :: MIT License',
34
- 'Operating System :: OS Independent',
35
- ],
36
- python_requires='>=3.6'
37
- )