Upload 32 files
- DenseAV/.gitignore +5 -0
- DenseAV/LICENSE +22 -0
- DenseAV/README.md +172 -0
- DenseAV/__init__.py +0 -0
- DenseAV/demo.ipynb +0 -0
- DenseAV/denseav/__init__.py +0 -0
- DenseAV/denseav/aggregators.py +517 -0
- DenseAV/denseav/aligners.py +300 -0
- DenseAV/denseav/configs/av_align.yaml +125 -0
- DenseAV/denseav/constants.py +12 -0
- DenseAV/denseav/data/AVDatasets.py +1249 -0
- DenseAV/denseav/data/__init__.py +0 -0
- DenseAV/denseav/data/make_tarballs.py +108 -0
- DenseAV/denseav/eval_utils.py +135 -0
- DenseAV/denseav/evaluate.py +87 -0
- DenseAV/denseav/featurizers/AudioMAE.py +570 -0
- DenseAV/denseav/featurizers/CAVMAE.py +1082 -0
- DenseAV/denseav/featurizers/CLIP.py +50 -0
- DenseAV/denseav/featurizers/DAVENet.py +162 -0
- DenseAV/denseav/featurizers/DINO.py +451 -0
- DenseAV/denseav/featurizers/DINOv2.py +49 -0
- DenseAV/denseav/featurizers/Hubert.py +70 -0
- DenseAV/denseav/featurizers/ImageBind.py +2033 -0
- DenseAV/denseav/featurizers/__init__.py +0 -0
- DenseAV/denseav/plotting.py +244 -0
- DenseAV/denseav/saved_models.py +262 -0
- DenseAV/denseav/shared.py +739 -0
- DenseAV/denseav/train.py +1222 -0
- DenseAV/gradio_app.py +196 -0
- DenseAV/hubconf.py +25 -0
- DenseAV/samples/puppies.mp4 +3 -0
- DenseAV/setup.py +37 -0
DenseAV/.gitignore
ADDED
@@ -0,0 +1,5 @@
# Created by .ignore support plugin (hsz.mobi)
results/attention/*
results/features/*

.env

DenseAV/LICENSE
ADDED
@@ -0,0 +1,22 @@
MIT License

Copyright (c) Mark Hamilton. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

DenseAV/README.md
ADDED
@@ -0,0 +1,172 @@
# Separating the "Chirp" from the "Chat": Self-supervised Visual Grounding of Sound and Language
### CVPR 2024


[Website](https://aka.ms/denseav) [Paper](https://arxiv.org/abs/2406.05629) [Open in Colab](https://colab.research.google.com/github/mhamilton723/DenseAV/blob/main/demo.ipynb)

[Huggingface Demo](https://huggingface.co/spaces/mhamilton723/DenseAV)

[//]: # ([](https://huggingface.co/papers/2403.10516))
[PWC: Speech Prompted Semantic Segmentation](https://paperswithcode.com/sota/speech-prompted-semantic-segmentation-on?p=separating-the-chirp-from-the-chat-self)
[PWC: Sound Prompted Semantic Segmentation](https://paperswithcode.com/sota/sound-prompted-semantic-segmentation-on?p=separating-the-chirp-from-the-chat-self)


[Mark Hamilton](https://mhamilton.net/),
[Andrew Zisserman](https://www.robots.ox.ac.uk/~az/),
[John R. Hershey](https://research.google/people/john-hershey/),
[William T. Freeman](https://billf.mit.edu/about/bio)



**TL;DR**: Our model, DenseAV, learns the meaning of words and the location of sounds (visual grounding) without supervision or text.

https://github.com/mhamilton723/DenseAV/assets/6456637/ba908ab5-9618-42f9-8d7a-30ecb009091f


## Contents
<!--ts-->
* [Install](#install)
* [Model Zoo](#model-zoo)
* [Getting Datasets](#getting-datasets)
* [Evaluate Models](#evaluate-models)
* [Train a Model](#train-a-model)
* [Local Gradio Demo](#local-gradio-demo)
* [Coming Soon](#coming-soon)
* [Citation](#citation)
* [Contact](#contact)
<!--te-->

## Install

To use DenseAV locally, clone the repository:

```shell script
git clone https://github.com/mhamilton723/DenseAV.git
cd DenseAV
pip install -e .
```


## Model Zoo

To see examples of pretrained model usage, please see our [Colab notebook](https://colab.research.google.com/github/mhamilton723/DenseAV/blob/main/demo.ipynb). We currently supply the following pretrained models:

| Model Name                    | Checkpoint                                                                                           | Torch Hub Repository | Torch Hub Name     |
|-------------------------------|------------------------------------------------------------------------------------------------------|----------------------|--------------------|
| Sound                         | [Download](https://marhamilresearch4.blob.core.windows.net/denseav-public/hub/denseav_sound.ckpt)    | mhamilton723/DenseAV | sound              |
| Language                      | [Download](https://marhamilresearch4.blob.core.windows.net/denseav-public/hub/denseav_language.ckpt) | mhamilton723/DenseAV | language           |
| Sound + Language (Two Headed) | [Download](https://marhamilresearch4.blob.core.windows.net/denseav-public/hub/denseav_2head.ckpt)    | mhamilton723/DenseAV | sound_and_language |

For example, to load the model trained on both sound and language:

```python
import torch
model = torch.hub.load("mhamilton723/DenseAV", 'sound_and_language')
```

### Load from HuggingFace

```python
from denseav.train import LitAVAligner

model1 = LitAVAligner.from_pretrained("mhamilton723/DenseAV-sound")
model2 = LitAVAligner.from_pretrained("mhamilton723/DenseAV-language")
model3 = LitAVAligner.from_pretrained("mhamilton723/DenseAV-sound-language")
```


## Getting Datasets

Our code assumes that all data lives in a common directory on your system; in these examples we use `/path/to/your/data`. Our code will often reference this directory as the `data_root`.

### Speech and Sound Prompted ADE20K

To download our new Speech and Sound Prompted ADE20K dataset:

```bash
cd /path/to/your/data
wget https://marhamilresearch4.blob.core.windows.net/denseav-public/datasets/ADE20KSoundPrompted.zip
unzip ADE20KSoundPrompted.zip
wget https://marhamilresearch4.blob.core.windows.net/denseav-public/datasets/ADE20KSpeechPrompted.zip
unzip ADE20KSpeechPrompted.zip
```

### Places Audio

First, download the Places Audio dataset from its [original source](https://groups.csail.mit.edu/sls/downloads/placesaudio/downloads.cgi).

To run the code, the data will need to be processed into the following form:

```
[Instructions coming soon]
```

### Audioset

Because of copyright issues we cannot make [Audioset](https://research.google.com/audioset/dataset/index.html) easily available for download.
First, download this dataset through appropriate means. [This other project](https://github.com/ktonal/audioset-downloader) appears to make this simple.

To run the code, the data will need to be processed into the following form:

```
[Instructions coming soon]
```


## Evaluate Models

To evaluate a trained model, first clone the repository for
[local development](#install). Then run

```shell
cd denseav
python evaluate.py
```

After evaluation, see the results in TensorBoard's hparams tab.

```shell
cd ../logs/evaluate
tensorboard --logdir .
```

Then visit [http://localhost:6006](http://localhost:6006) and click on hparams to browse results. We report "advanced" speech metrics and "basic" sound metrics in our paper.


## Train a Model

```shell
cd denseav
python train.py
```

## Local Gradio Demo

To run our [HuggingFace Spaces hosted DenseAV demo](https://huggingface.co/spaces/mhamilton723/DenseAV) locally, first install DenseAV for local development. Then run:

```shell
python gradio_app.py
```

Wait a few seconds for the demo to spin up, then navigate to [http://localhost:7860/](http://localhost:7860/) to view the demo.


## Coming Soon

- Bigger models!

## Citation

```
@misc{hamilton2024separating,
      title={Separating the "Chirp" from the "Chat": Self-supervised Visual Grounding of Sound and Language},
      author={Mark Hamilton and Andrew Zisserman and John R. Hershey and William T. Freeman},
      year={2024},
      eprint={2406.05629},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

## Contact

For feedback, questions, or press inquiries, please contact [Mark Hamilton](mailto:[email protected]).

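The Model Zoo snippets in the README above stop at the `load` call. Below is a minimal sketch, not taken from the repository, of a complete loading script; it assumes only that the Torch Hub entry point returns a standard `torch.nn.Module` (the exact forward-pass inputs are shown in the Colab notebook rather than in the README).

```python
# Minimal sketch (not part of the repo): load the two-headed DenseAV model and
# prepare it for inference. The entry point name comes from the Model Zoo table.
import torch

model = torch.hub.load("mhamilton723/DenseAV", "sound_and_language")
model = model.eval()                      # inference mode
if torch.cuda.is_available():             # GPU is optional
    model = model.cuda()
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```
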
DenseAV/__init__.py
ADDED
File without changes
DenseAV/demo.ipynb
ADDED
The diff for this file is too large to render.
DenseAV/denseav/__init__.py
ADDED
File without changes
DenseAV/denseav/aggregators.py
ADDED
@@ -0,0 +1,517 @@
from abc import abstractmethod

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from denseav.constants import *


@torch.jit.script
def masked_mean(x: torch.Tensor, mask: torch.Tensor, dim: int):
    mask = mask.to(x)
    return (x * mask).sum(dim, keepdim=True) / mask.sum(dim, keepdim=True).clamp_min(.001)


@torch.jit.script
def masked_max(x: torch.Tensor, mask: torch.Tensor, dim: int):
    mask = mask.to(torch.bool)
    eps = 1e7
    # eps = torch.finfo(x.dtype).max
    return (x - (~mask) * eps).max(dim, keepdim=True).values


def masked_lse(x: torch.Tensor, mask: torch.Tensor, dim: int, temp):
    x = x.to(torch.float32)
    mask = mask.to(torch.float32)
    x_masked = (x - (1 - mask) * torch.finfo(x.dtype).max)
    return (torch.logsumexp(x_masked * temp, dim, keepdim=True) - torch.log(mask.sum(dim, keepdim=True))) / temp


class BaseAggregator(torch.nn.Module):

    def __init__(self, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
        super().__init__()

        self.nonneg_sim = nonneg_sim
        self.mask_silence = mask_silence
        self.num_heads = num_heads
        self.head_agg = head_agg
        self.use_cls = use_cls

    @abstractmethod
    def _agg_sim(self, sim, mask):
        pass

    def prepare_sims(self, sim, mask, agg_sim, agg_heads):
        sim_size = sim.shape
        assert len(mask.shape) == 2
        assert len(sim_size) in {6, 7}, f"sim has wrong number of dimensions: {sim.shape}"
        pairwise = len(sim_size) == 6

        if self.mask_silence:
            mask = mask
        else:
            mask = torch.ones_like(mask)

        if self.nonneg_sim:
            sim = sim.clamp_min(0)

        if pairwise:
            head_dim = 1
        else:
            head_dim = 2

        if self.head_agg == "max_elementwise" and agg_heads:
            sim = sim.max(head_dim, keepdim=True).values

        if agg_sim:
            sim = self._agg_sim(sim, mask)

        if agg_heads:
            if self.head_agg == "sum" or self.head_agg == "max_elementwise":
                sim = sim.sum(head_dim)
            elif self.head_agg == "max":
                sim = sim.max(head_dim).values
            else:
                raise ValueError(f"Unknown head_agg: {self.head_agg}")

        return sim

    def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim or agg_heads or raw:
            assert (agg_sim or agg_heads) != raw, "Cannot have raw on at the same time as agg_sim or agg_heads"

        audio_feats = preds[AUDIO_FEATS]
        audio_mask = preds[AUDIO_MASK]
        image_feats = preds[IMAGE_FEATS]

        b1, c2, f, t1 = audio_feats.shape
        b2, t2 = audio_mask.shape
        d, c1, h, w = image_feats.shape
        assert b1 == b2 and c1 == c2 and t1 == t2
        assert c1 % self.num_heads == 0
        new_c = c1 // self.num_heads
        audio_feats = audio_feats.reshape(b1, self.num_heads, new_c, f, t1)
        image_feats = image_feats.reshape(d, self.num_heads, new_c, h, w)
        raw_sims = torch.einsum(
            "akcft,vkchw->avkhwft",
            audio_feats.to(torch.float32),
            image_feats.to(torch.float32))

        if self.use_cls:
            audio_cls = preds[AUDIO_CLS].reshape(b1, self.num_heads, new_c)
            image_cls = preds[IMAGE_CLS].reshape(d, self.num_heads, new_c)
            cls_sims = torch.einsum(
                "akc,vkc->avk",
                audio_cls.to(torch.float32),
                image_cls.to(torch.float32))
            raw_sims += cls_sims.reshape(b1, d, self.num_heads, 1, 1, 1, 1)

        if raw:
            return raw_sims
        else:
            return self.prepare_sims(raw_sims, audio_mask, agg_sim, agg_heads)

    def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim or agg_heads or raw:
            assert (agg_sim or agg_heads) != raw, "Cannot have raw on at the same time as agg_sim or agg_heads"

        audio_feats = preds[AUDIO_FEATS]
        audio_mask = preds[AUDIO_MASK]
        image_feats = preds[IMAGE_FEATS]

        a1, c1, f, t1 = audio_feats.shape
        a2, t2 = audio_mask.shape

        assert c1 % self.num_heads == 0
        new_c = c1 // self.num_heads
        audio_feats = audio_feats.reshape(a1, self.num_heads, new_c, f, t1)

        if len(image_feats.shape) == 5:
            print("Using similarity for video, should only be called during plotting")
            v, vt, c2, h, w = image_feats.shape
            image_feats = image_feats.reshape(v, vt, self.num_heads, new_c, h, w)
            raw_sims = torch.einsum(
                "bkcft,bskchw,bt->bskhwft",
                audio_feats.to(torch.float32),
                image_feats.to(torch.float32),
                audio_mask.to(torch.float32))

            if self.use_cls:
                audio_cls = preds[AUDIO_CLS].reshape(v, self.num_heads, new_c)
                image_cls = preds[IMAGE_CLS].reshape(v, vt, self.num_heads, new_c)
                cls_sims = torch.einsum(
                    "bkc,bskc->bsk",
                    audio_cls.to(torch.float32),
                    image_cls.to(torch.float32))
                raw_sims += cls_sims.reshape(v, vt, self.num_heads, 1, 1, 1, 1)

        elif len(image_feats.shape) == 4:
            v, c2, h, w = image_feats.shape
            image_feats = image_feats.reshape(v, self.num_heads, new_c, h, w)
            raw_sims = torch.einsum(
                "bkcft,bkchw,bt->bkhwft",
                audio_feats.to(torch.float32),
                image_feats.to(torch.float32),
                audio_mask.to(torch.float32))

            if self.use_cls:
                audio_cls = preds[AUDIO_CLS].reshape(v, self.num_heads, new_c)
                image_cls = preds[IMAGE_CLS].reshape(v, self.num_heads, new_c)
                cls_sims = torch.einsum(
                    "bkc,bkc->bk",
                    audio_cls.to(torch.float32),
                    image_cls.to(torch.float32))
                raw_sims += cls_sims.reshape(v, self.num_heads, 1, 1, 1, 1)
        else:
            raise ValueError(f"Improper image shape: {image_feats.shape}")

        assert a1 == a2 and c1 == c2 and t1 == t2

        if raw:
            return raw_sims
        else:
            return self.prepare_sims(raw_sims, audio_mask, agg_sim, agg_heads)

    def forward(self, preds, agg_heads):
        return self._get_full_sims(
            preds, raw=False, agg_sim=True, agg_heads=agg_heads)

    def forward_batched(self, preds, agg_heads, batch_size):
        new_preds = {k: v for k, v in preds.items()}
        big_image_feats = new_preds.pop(IMAGE_FEATS)
        if self.use_cls:
            big_image_cls = new_preds.pop(IMAGE_CLS)

        n = big_image_feats.shape[0]
        n_steps = math.ceil(n / batch_size)
        outputs = []
        for step in tqdm(range(n_steps), "Calculating Sim", leave=False):
            new_preds[IMAGE_FEATS] = big_image_feats[step * batch_size:(step + 1) * batch_size].cuda()
            if self.use_cls:
                new_preds[IMAGE_CLS] = big_image_cls[step * batch_size:(step + 1) * batch_size].cuda()

            sim = self.forward(new_preds, agg_heads=agg_heads)
            outputs.append(sim.cpu())
        return torch.cat(outputs, dim=1)


class ImageThenAudioAggregator(BaseAggregator):

    def __init__(self, image_agg_type, audio_agg_type, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
        super().__init__(nonneg_sim, mask_silence, num_heads, head_agg, use_cls)
        if image_agg_type == "max":
            self.image_agg = lambda x, dim: x.max(dim=dim, keepdim=True).values
        elif image_agg_type == "avg":
            self.image_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)
        else:
            raise ValueError(f"Unknown image_agg_type {image_agg_type}")

        if audio_agg_type == "max":
            self.time_agg = masked_max
        elif audio_agg_type == "avg":
            self.time_agg = masked_mean
        else:
            raise ValueError(f"Unknown audio_agg_type {audio_agg_type}")

        self.freq_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)

    def _agg_sim(self, sim, mask):
        sim_shape = sim.shape
        new_mask_shape = [1] * len(sim_shape)
        new_mask_shape[0] = sim_shape[0]
        new_mask_shape[-1] = sim_shape[-1]
        mask = mask.reshape(new_mask_shape)
        sim = self.image_agg(sim, -3)
        sim = self.image_agg(sim, -4)
        sim = self.freq_agg(sim, -2)
        sim = self.time_agg(sim, mask, -1)
        return sim.squeeze(-1).squeeze(-1).squeeze(-1).squeeze(-1)


class PairedAggregator(BaseAggregator):

    def __init__(self, nonneg_sim, mask_silence, num_heads, head_agg, use_cls):
        super().__init__(nonneg_sim, mask_silence, num_heads, head_agg, use_cls)
        self.image_agg_max = lambda x, dim: x.max(dim=dim, keepdim=True).values
        self.image_agg_mean = lambda x, dim: x.mean(dim=dim, keepdim=True)

        self.time_agg_max = masked_max
        self.time_agg_mean = masked_mean

        self.freq_agg = lambda x, dim: x.mean(dim=dim, keepdim=True)

    def _agg_sim(self, sim, mask):
        sim_shape = sim.shape
        new_mask_shape = [1] * len(sim_shape)
        new_mask_shape[0] = sim_shape[0]
        new_mask_shape[-1] = sim_shape[-1]
        mask = mask.reshape(new_mask_shape)

        sim_1 = self.image_agg_max(sim, -3)
        sim_1 = self.image_agg_max(sim_1, -4)
        sim_1 = self.freq_agg(sim_1, -2)
        sim_1 = self.time_agg_mean(sim_1, mask, -1)

        sim_2 = self.freq_agg(sim, -2)
        sim_2 = self.time_agg_max(sim_2, mask, -1)
        sim_2 = self.image_agg_mean(sim_2, -3)
        sim_2 = self.image_agg_mean(sim_2, -4)

        sim = 1 / 2 * (sim_1 + sim_2)

        return sim.squeeze(-1).squeeze(-1).squeeze(-1).squeeze(-1)


class CAVMAEAggregator(BaseAggregator):

    def __init__(self, *args, **kwargs):
        super().__init__(False, False, 1, "sum", False)

    def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            audio_feats = preds[AUDIO_FEATS]
            image_feats = preds[IMAGE_FEATS]
            pool_audio_feats = F.normalize(audio_feats.mean(dim=[-1, -2]), dim=1)
            pool_image_feats = F.normalize(image_feats.mean(dim=[-1, -2]), dim=1)
            sims = torch.einsum(
                "bc,dc->bd",
                pool_audio_feats.to(torch.float32),
                pool_image_feats.to(torch.float32))
            if agg_heads:
                return sims
            else:
                return sims.unsqueeze(-1)
        else:
            return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)

    def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            audio_feats = preds[AUDIO_FEATS]
            image_feats = preds[IMAGE_FEATS]
            pool_audio_feats = F.normalize(audio_feats.mean(dim=[-1, -2]), dim=1)
            pool_image_feats = F.normalize(image_feats.mean(dim=[-1, -2]), dim=1)
            sims = torch.einsum(
                "bc,bc->b",
                pool_audio_feats.to(torch.float32),
                pool_image_feats.to(torch.float32))
            if agg_heads:
                return sims
            else:
                return sims.unsqueeze(-1)
        else:
            return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)


class ImageBindAggregator(BaseAggregator):

    def __init__(self, num_heads, *args, **kwargs):
        super().__init__(False, False, num_heads, "sum", False)

    def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            sims = torch.einsum(
                "bc,dc->bd",
                preds[AUDIO_CLS].to(torch.float32),
                preds[IMAGE_CLS].to(torch.float32))
            if agg_heads:
                return sims
            else:
                sims = sims.unsqueeze(-1)
                return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
        else:
            return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)

    def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            sims = torch.einsum(
                "bc,dc->b",
                preds[AUDIO_CLS].to(torch.float32),
                preds[IMAGE_CLS].to(torch.float32))
            if agg_heads:
                return sims
            else:
                sims = sims.unsqueeze(-1)
                return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
        else:
            return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)

    def forward_batched(self, preds, agg_heads, batch_size):
        return self.forward(preds, agg_heads)


class SimPool(nn.Module):
    def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, gamma=None, use_beta=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.norm_patches = nn.LayerNorm(dim, eps=1e-6)

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)

        if gamma is not None:
            self.gamma = torch.tensor([gamma])
            if use_beta:
                self.beta = nn.Parameter(torch.tensor([0.0]))
        self.eps = torch.tensor([1e-6])

        self.gamma = gamma
        self.use_beta = use_beta

    def prepare_input(self, x):
        if len(x.shape) == 3:  # Transformer
            # Input tensor dimensions:
            # x: (B, N, d), where B is batch size, N are patch tokens, d is depth (channels)
            B, N, d = x.shape
            gap_cls = x.mean(-2)  # (B, N, d) -> (B, d)
            gap_cls = gap_cls.unsqueeze(1)  # (B, d) -> (B, 1, d)
            return gap_cls, x
        if len(x.shape) == 4:  # CNN
            # Input tensor dimensions:
            # x: (B, d, H, W), where B is batch size, d is depth (channels), H is height, and W is width
            B, d, H, W = x.shape
            gap_cls = x.mean([-2, -1])  # (B, d, H, W) -> (B, d)
            x = x.reshape(B, d, H * W).permute(0, 2, 1)  # (B, d, H, W) -> (B, d, H*W) -> (B, H*W, d)
            gap_cls = gap_cls.unsqueeze(1)  # (B, d) -> (B, 1, d)
            return gap_cls, x
        else:
            raise ValueError(f"Unsupported number of dimensions in input tensor: {len(x.shape)}")

    def forward(self, x):
        self.eps = self.eps.to(x.device)
        # Prepare input tensor and perform GAP as initialization
        gap_cls, x = self.prepare_input(x)

        # Prepare queries (q), keys (k), and values (v)
        q, k, v = gap_cls, self.norm_patches(x), self.norm_patches(x)

        # Extract dimensions after normalization
        Bq, Nq, dq = q.shape
        Bk, Nk, dk = k.shape
        Bv, Nv, dv = v.shape

        # Check dimension consistency across batches and channels
        assert Bq == Bk == Bv
        assert dq == dk == dv

        # Apply linear transformation for queries and keys then reshape
        qq = self.wq(q).reshape(Bq, Nq, self.num_heads, dq // self.num_heads).permute(0, 2, 1, 3)  # (Bq, Nq, dq) -> (B, num_heads, Nq, dq/num_heads)
        kk = self.wk(k).reshape(Bk, Nk, self.num_heads, dk // self.num_heads).permute(0, 2, 1, 3)  # (Bk, Nk, dk) -> (B, num_heads, Nk, dk/num_heads)

        vv = v.reshape(Bv, Nv, self.num_heads, dv // self.num_heads).permute(0, 2, 1, 3)  # (Bv, Nv, dv) -> (B, num_heads, Nv, dv/num_heads)

        # Compute attention scores
        attn = (qq @ kk.transpose(-2, -1)) * self.scale
        # Apply softmax for normalization
        attn = attn.softmax(dim=-1)

        # If gamma scaling is used
        if self.gamma is not None:
            # Apply gamma scaling on values and compute the weighted sum using attention scores
            x = torch.pow(attn @ torch.pow((vv - vv.min() + self.eps), self.gamma), 1 / self.gamma)  # (B, num_heads, Nv, dv/num_heads) -> (B, 1, 1, d)
            # If use_beta, add a learnable translation
            if self.use_beta:
                x = x + self.beta
        else:
            # Compute the weighted sum using attention scores
            x = (attn @ vv).transpose(1, 2).reshape(Bq, Nq, dq)

        return x.squeeze()


class SimPoolAggregator(BaseAggregator):

    def __init__(self, num_heads, dim, *args, **kwargs):
        super().__init__(False, False, num_heads, "sum", False)
        self.pool = SimPool(dim, gamma=1.25)

    def _get_full_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            device = self.pool.wq.weight.data.device
            pooled_audio = self.pool(preds[AUDIO_FEATS].to(torch.float32).to(device))
            pooled_image = self.pool(preds[IMAGE_FEATS].to(torch.float32).to(device))

            sims = torch.einsum(
                "bc,dc->bd",
                pooled_audio,
                pooled_image)
            if agg_heads:
                return sims
            else:
                sims = sims.unsqueeze(-1)
                return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
        else:
            return BaseAggregator._get_full_sims(self, preds, raw, agg_sim, agg_heads)

    def get_pairwise_sims(self, preds, raw, agg_sim, agg_heads):
        if agg_sim:
            device = self.pool.wq.weight.data.device
            pooled_audio = self.pool(preds[AUDIO_FEATS].to(torch.float32).to(device))
            pooled_image = self.pool(preds[IMAGE_FEATS].to(torch.float32).to(device))

            sims = torch.einsum(
                "bc,dc->b",
                pooled_audio,
                pooled_image)
            if agg_heads:
                return sims
            else:
                sims = sims.unsqueeze(-1)
                return sims.repeat(*([1] * (sims.dim() - 1)), self.num_heads)
        else:
            return BaseAggregator.get_pairwise_sims(self, preds, raw, agg_sim, agg_heads)

    def forward_batched(self, preds, agg_heads, batch_size):
        return self.forward(preds, agg_heads)


def get_aggregator(sim_agg_type, nonneg_sim, mask_silence, num_heads, head_agg, use_cls, dim):
    shared_args = dict(
        nonneg_sim=nonneg_sim,
        mask_silence=mask_silence,
        num_heads=num_heads,
        head_agg=head_agg,
        use_cls=use_cls,
    )

    if sim_agg_type == "paired":
        agg1 = PairedAggregator(**shared_args)
    elif sim_agg_type == "misa":
        agg1 = ImageThenAudioAggregator("max", "avg", **shared_args)
    elif sim_agg_type == "mima":
        agg1 = ImageThenAudioAggregator("max", "max", **shared_args)
    elif sim_agg_type == "sisa":
        agg1 = ImageThenAudioAggregator("avg", "avg", **shared_args)
    elif sim_agg_type == "cavmae":
        agg1 = CAVMAEAggregator()
    elif sim_agg_type == "imagebind":
        agg1 = ImageBindAggregator(num_heads=shared_args["num_heads"])
    elif sim_agg_type == "simpool":
        agg1 = SimPoolAggregator(num_heads=shared_args["num_heads"], dim=dim)
    else:
        raise ValueError(f"Unknown loss_type {sim_agg_type}")

    return agg1

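As a quick orientation to `aggregators.py` above, the following sketch (not part of the repository) builds the "misa" aggregator selected in `denseav/configs/av_align.yaml` and runs it on randomly generated features. The tensor sizes are illustrative assumptions; only the argument names and dictionary keys come from the code above.

```python
# Minimal sketch (not part of the repo): build the "misa" aggregator and compute
# an (audio, image) similarity matrix from random features.
import torch
from denseav.aggregators import get_aggregator
from denseav.constants import AUDIO_FEATS, AUDIO_MASK, IMAGE_FEATS

agg = get_aggregator(
    sim_agg_type="misa",           # values mirror denseav/configs/av_align.yaml
    nonneg_sim=False,
    mask_silence=True,
    num_heads=1,
    head_agg="max_elementwise",
    use_cls=False,
    dim=512)                       # dim is only used by the "simpool" aggregator

b, c, f, t, h, w = 2, 512, 1, 10, 14, 14   # illustrative sizes, not from the repo
preds = {
    AUDIO_FEATS: torch.randn(b, c, f, t),  # (batch, channels, freq, time)
    AUDIO_MASK: torch.ones(b, t),          # 1 where the audio is not silence
    IMAGE_FEATS: torch.randn(b, c, h, w),  # (batch, channels, height, width)
}
sims = agg(preds, agg_heads=True)          # aggregates over space, freq, time, heads
print(sims.shape)                          # torch.Size([2, 2]): audio-by-image similarities
```
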
DenseAV/denseav/aligners.py
ADDED
@@ -0,0 +1,300 @@
from functools import partial

import torch
import torch.nn.functional as F
from torch.nn import ModuleList

from denseav.featurizers.DINO import Block


class ChannelNorm(torch.nn.Module):

    def __init__(self, dim, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = torch.nn.LayerNorm(dim, eps=1e-4)

    def forward_spatial(self, x):
        return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x, cls):
        return self.forward_spatial(x), self.forward_cls(cls)

    def forward_cls(self, cls):
        if cls is not None:
            return self.norm(cls)
        else:
            return None


def id_conv(dim, strength=.9):
    conv = torch.nn.Conv2d(dim, dim, 1, padding="same")
    start_w = conv.weight.data
    conv.weight.data = torch.nn.Parameter(
        torch.eye(dim, device=start_w.device).unsqueeze(-1).unsqueeze(-1) * strength + start_w * (1 - strength))
    conv.bias.data = torch.nn.Parameter(conv.bias.data * (1 - strength))
    return conv


class LinearAligner(torch.nn.Module):
    def __init__(self, in_dim, out_dim, use_norm=True):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        if use_norm:
            self.norm = ChannelNorm(in_dim)
        else:
            self.norm = Identity2()

        if in_dim == out_dim:
            self.layer = id_conv(in_dim, 0)
        else:
            self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1)

        self.cls_layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, spatial, cls):
        norm_spatial, norm_cls = self.norm(spatial, cls)

        if cls is not None:
            aligned_cls = self.cls_layer(cls)
        else:
            aligned_cls = None

        return self.layer(norm_spatial), aligned_cls


class IdLinearAligner(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        assert self.out_dim == self.in_dim
        self.layer = id_conv(in_dim, 1.0)

    def forward(self, spatial, cls):
        return self.layer(spatial), cls


class FrequencyAvg(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, spatial, cls):
        return spatial.mean(2, keepdim=True), cls


class LearnedTimePool(torch.nn.Module):
    def __init__(self, dim, width, maxpool):
        super().__init__()
        self.dim = dim
        self.width = width
        self.norm = ChannelNorm(dim)
        if maxpool:
            self.layer = torch.nn.Sequential(
                torch.nn.Conv2d(dim, dim, kernel_size=width, stride=1, padding="same"),
                torch.nn.MaxPool2d(kernel_size=(1, width), stride=(1, width))
            )
        else:
            self.layer = torch.nn.Conv2d(dim, dim, kernel_size=(1, width), stride=(1, width))

    def forward(self, spatial, cls):
        norm_spatial, norm_cls = self.norm(spatial, cls)
        return self.layer(norm_spatial), norm_cls


class LearnedTimePool2(torch.nn.Module):
    def __init__(self, in_dim, out_dim, width, maxpool, use_cls_layer):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.width = width

        if maxpool:
            self.layer = torch.nn.Sequential(
                torch.nn.Conv2d(in_dim, out_dim, kernel_size=width, stride=1, padding="same"),
                torch.nn.MaxPool2d(kernel_size=(1, width), stride=(1, width))
            )
        else:
            self.layer = torch.nn.Conv2d(in_dim, out_dim, kernel_size=(1, width), stride=(1, width))

        self.use_cls_layer = use_cls_layer
        if use_cls_layer:
            self.cls_layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, spatial, cls):

        if cls is not None:
            if self.use_cls_layer:
                aligned_cls = self.cls_layer(cls)
            else:
                aligned_cls = cls
        else:
            aligned_cls = None

        return self.layer(spatial), aligned_cls


class Sequential2(torch.nn.Module):

    def __init__(self, *modules):
        super().__init__()
        self.mod_list = ModuleList(modules)

    def forward(self, x, y):
        results = (x, y)
        for m in self.mod_list:
            results = m(*results)
        return results


class ProgressiveGrowing(torch.nn.Module):

    def __init__(self, stages, phase_lengths):
        super().__init__()
        self.stages = torch.nn.ModuleList(stages)
        self.phase_lengths = torch.tensor(phase_lengths)
        assert len(self.phase_lengths) + 1 == len(self.stages)
        self.phase_boundaries = self.phase_lengths.cumsum(0)
        self.register_buffer('phase', torch.tensor([1]))

    def maybe_change_phase(self, global_step):
        needed_phase = (global_step >= self.phase_boundaries).to(torch.int64).sum().item() + 1
        if needed_phase != self.phase.item():
            print(f"Changing aligner phase to {needed_phase}")
            self.phase.copy_(torch.tensor([needed_phase]).to(self.phase.device))
            return True
        else:
            return False

    def parameters(self, recurse: bool = True):
        phase = self.phase.item()
        used_stages = self.stages[:phase]
        print(f"Progressive Growing at stage {phase}")
        all_params = []
        for stage in used_stages:
            all_params.extend(stage.parameters(recurse))
        return iter(all_params)

    def forward(self, spatial, cls):
        pipeline = Sequential2(*self.stages[:self.phase.item()])
        return pipeline(spatial, cls)


class Identity2(torch.nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x, y


class SelfAttentionAligner(torch.nn.Module):

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

        self.num_heads = 6
        if dim % self.num_heads != 0:
            self.padding = self.num_heads - (dim % self.num_heads)
        else:
            self.padding = 0

        self.block = Block(
            dim + self.padding,
            num_heads=self.num_heads,
            mlp_ratio=4,
            qkv_bias=True,
            qk_scale=None,
            drop=0.0,
            attn_drop=0.0,
            drop_path=0.0,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-4))

    def forward(self, spatial, cls):
        padded_feats = F.pad(spatial, [0, 0, 0, 0, self.padding, 0])

        B, C, H, W = padded_feats.shape
        proj_feats = padded_feats.reshape(B, C, H * W).permute(0, 2, 1)

        if cls is not None:
            assert len(cls.shape) == 2
            padded_cls = F.pad(cls, [self.padding, 0])
            proj_feats = torch.cat([padded_cls.unsqueeze(1), proj_feats], dim=1)

        aligned_feat, attn, qkv = self.block(proj_feats, return_qkv=True)

        if cls is not None:
            aligned_cls = aligned_feat[:, 0, :]
            aligned_spatial = aligned_feat[:, 1:, :]
        else:
            aligned_cls = None
            aligned_spatial = aligned_feat

        aligned_spatial = aligned_spatial.reshape(B, H, W, self.dim + self.padding).permute(0, 3, 1, 2)

        aligned_spatial = aligned_spatial[:, self.padding:, :, :]
        if aligned_cls is not None:
            aligned_cls = aligned_cls[:, self.padding:]

        return aligned_spatial, aligned_cls


def get_aligner(aligner_type, in_dim, out_dim, **kwargs):
    if aligner_type is None:
        return Identity2()

    if "prog" in aligner_type:
        phase_length = kwargs["phase_length"]

    if aligner_type == "image_linear":
        return LinearAligner(in_dim, out_dim)
    elif aligner_type == "image_idlinear":
        return IdLinearAligner(in_dim, out_dim)
    elif aligner_type == "image_linear_no_norm":
        return LinearAligner(in_dim, out_dim, use_norm=False)
    elif aligner_type == "image_id":
        return Identity2()
    elif aligner_type == "image_norm":
        return ChannelNorm(in_dim)
    elif aligner_type == "audio_linear":
        return Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg())
    elif aligner_type == "audio_sa":
        return Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            SelfAttentionAligner(out_dim)
        )
    elif aligner_type == "audio_sa_sa":
        return Sequential2(
            FrequencyAvg(),
            LinearAligner(in_dim, out_dim),
            SelfAttentionAligner(out_dim),
            SelfAttentionAligner(out_dim)
        )
    elif aligner_type == "audio_3_3_pool":
        return Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            LearnedTimePool(out_dim, 3, False),
            LearnedTimePool(out_dim, 3, False),
        )
    elif aligner_type == "audio_sa_3_3_pool":
        return Sequential2(
            LinearAligner(in_dim, out_dim),
            FrequencyAvg(),
            LearnedTimePool(out_dim, 3, False),
            LearnedTimePool(out_dim, 3, False),
            SelfAttentionAligner(out_dim)
        )
    elif aligner_type == "audio_sa_3_3_pool_2":
        return Sequential2(
            FrequencyAvg(),
            ChannelNorm(in_dim),
            LearnedTimePool2(in_dim, out_dim, 3, False, True),
            LearnedTimePool2(out_dim, out_dim, 3, False, False),
            SelfAttentionAligner(out_dim)
        )
    else:
        raise ValueError(f"Unknown aligner type {aligner_type}")

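For orientation, here is a small sketch (not part of the repository) that instantiates one of the aligners defined above via `get_aligner`. The `"image_linear"` name matches the image aligner chosen in `denseav/configs/av_align.yaml`; the 384-to-512 dimensions and the feature-map size are illustrative assumptions.

```python
# Minimal sketch (not part of the repo): build the linear image aligner and run
# a dummy backbone feature map through it.
import torch
from denseav.aligners import get_aligner

aligner = get_aligner("image_linear", in_dim=384, out_dim=512)

spatial = torch.randn(2, 384, 28, 28)                   # (batch, channels, height, width)
aligned_spatial, aligned_cls = aligner(spatial, None)   # no CLS token in this sketch
print(aligned_spatial.shape)                            # torch.Size([2, 512, 28, 28])
print(aligned_cls)                                      # None
```
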
DenseAV/denseav/configs/av_align.yaml
ADDED
@@ -0,0 +1,125 @@
# Model args

code_dim: 384
image_model_type: "dino8"
image_model_token_type: "token"
image_aligner_type: "image_linear"
image_pool_width: 2

audio_model_type: "hubert"
audio_aligner_type: "audio_sa_3_3_pool_2"
audio_pool_width: 1

learn_audio_cls: True

#code_dim: 1024
#image_model_type: "imagebind"
#image_model_token_type: "token"
#image_aligner_type: "image_linear"
#image_pool_width: 1
#
#audio_model_type: "imagebind"
#audio_aligner_type: "audio_sa"
#audio_pool_width: 1
#
#learn_audio_cls: False

audio_lora: False
audio_lora_rank: 8
image_lora: True
image_lora_rank: 8


spatial_dropout: 0.0
channel_dropout: 0.0

quad_mixup: 0.1
bg_mixup: 0.0
patch_mixup: 0.0
mixup_weight: 0.1

sim_agg_type: "misa"
sim_agg_heads: 1
sim_use_cls: False

cal_init: 1.0
cal_balance_weight: 0.1
nonneg_sim: False
nonneg_pressure: 0.01
silence_l1: 0.01
silence_l2: 0.0
tv_weight: 0.01
specialization_weight: 0.05
head_agg: "max_elementwise"
disentangle_weight: 0.0

norm_vectors: False

neg_audio: true
neg_audio_weight: 0.01


pretrain_steps: 3000
pretrain_lr: .5e-4

# Loss args
lr: .5e-4
lr_warmup: 1000

#lr_warmup: 100

lr_schedule: ~
lr_cycle_length: 50000

optimizer: "adam"
gradient_clipping: 10.0
adaptive_clipping: True
gather_tensors: True
loss_type: "nce"
loss_leak: 0.0
loss_margin: 0.0
mask_silence: true
extra_audio_masking: true
max_steps: 1000001

finetune_image_model: False
finetune_audio_model: True

# Checkpointing args
load_strict: true
starting_weights: ~
auto_resume: false
grouping_name: "foo"
resume_prefix: "imagebind_exp2"

# Data Args
#dataset_name: "sample-audio"
dataset_name: "places-audio"
#dataset_name: "mixed"
#dataset_name: "audio-set-full"
use_extra_val_sets: true
batch_size: 10
load_size: 224
image_aug: true
audio_aug: false

audio_level: false

memory_buffer_size: 0

val_check_interval: 10000 #0
use_cached_embs: false
num_workers: 12
num_gpus: 4
num_sanity_val_steps: 0 #-1
seed: 0

# Env args
output_root: '../'
pytorch_data_dir: '/pytorch-data/'
submitting_to_aml: false

hydra:
  run:
    dir: "."
  output_subdir: ~

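The config above appears to be consumed by the Hydra-based training entry point (`denseav/train.py`). As a hedged sketch, assuming `omegaconf` (installed alongside Hydra) is available, it can also be inspected directly:

```python
# Minimal sketch (not part of the repo): read av_align.yaml without launching training.
from omegaconf import OmegaConf

cfg = OmegaConf.load("denseav/configs/av_align.yaml")
print(cfg.image_model_type, cfg.audio_model_type)  # dino8 hubert
print(cfg.sim_agg_type, cfg.sim_agg_heads)         # misa 1
# The `hydra:` block is only meaningful to Hydra itself; loaded this way it is
# just another config node (cfg.hydra.run.dir == ".").
```
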
DenseAV/denseav/constants.py
ADDED
@@ -0,0 +1,12 @@

IMAGE_INPUT = "frames"
IMAGE_FEATS = "image_feats"
IMAGE_CLS = "image_cls"
IMAGE_MASK = "image_masks"

AUDIO_FEATS = "audio_feats"
AUDIO_CLS = "audio_cls"
AUDIO_MASK = "audio_mask"
AUDIO_POS_MASK = "audio_pos_mask"

DATA_SOURCE = "source"

DenseAV/denseav/data/AVDatasets.py
ADDED
@@ -0,0 +1,1249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
from abc import ABC, abstractmethod
|
4 |
+
from glob import glob
|
5 |
+
from os.path import join
|
6 |
+
from pathlib import Path
|
7 |
+
from typing import List, Set
|
8 |
+
|
9 |
+
import audioread
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
import pytorch_lightning as pl
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
import torchaudio
|
16 |
+
import torchvision.transforms as T
|
17 |
+
from PIL import Image
|
18 |
+
from torch.utils.data import Dataset, DataLoader, default_collate, Subset, ConcatDataset
|
19 |
+
from tqdm import tqdm
|
20 |
+
|
21 |
+
from denseav.constants import AUDIO_MASK, AUDIO_POS_MASK, IMAGE_MASK, IMAGE_INPUT
|
22 |
+
from denseav.data.make_tarballs import untar_all
|
23 |
+
from denseav.shared import norm, prep_waveform
|
24 |
+
|
25 |
+
|
26 |
+
def sample_choice(choices, probs):
|
27 |
+
# Check that probabilities sum to 1 and are non-negative
|
28 |
+
assert sum(probs) == 1, "Probabilities must sum to 1"
|
29 |
+
assert all(p >= 0 for p in probs), "Probabilities cannot be negative"
|
30 |
+
|
31 |
+
# Convert probs to a tensor
|
32 |
+
probs_tensor = torch.tensor(probs)
|
33 |
+
|
34 |
+
# Sample a choice according to the probabilities
|
35 |
+
index = torch.multinomial(probs_tensor, 1).item()
|
36 |
+
|
37 |
+
# Return the sampled choice
|
38 |
+
return choices[index]
|
39 |
+
|
40 |
+
|
41 |
+
def grid_frames(frames):
|
42 |
+
top_row = torch.cat([frames[0], frames[1]], dim=2)
|
43 |
+
bottom_row = torch.cat([frames[2], frames[3]], dim=2)
|
44 |
+
return torch.cat([top_row, bottom_row], dim=3)
|
45 |
+
|
46 |
+
|
47 |
+
def create_mixed_image(pos_frame, neg_frame, patch_size):
|
48 |
+
# Step 1: Check that patch_size evenly divides the image dimensions
|
49 |
+
b, c, h, w = pos_frame.shape
|
50 |
+
assert h % patch_size == 0 and w % patch_size == 0, "Patch size must evenly divide image dimensions"
|
51 |
+
|
52 |
+
# Step 2: Create a random binary mask with the same number of patches as the image
|
53 |
+
mask = torch.randint(0, 2, (b, 1, h // patch_size, w // patch_size))
|
54 |
+
|
55 |
+
# Step 3: Create a new image using patches from pos_frame and neg_frame according to the mask
|
56 |
+
# Upscale the mask to the size of the image
|
57 |
+
mask_upscaled = F.interpolate(mask.to(torch.float32), scale_factor=patch_size)
|
58 |
+
|
59 |
+
# Use the mask to create a mixed frame
|
60 |
+
mixed_frame = mask_upscaled * pos_frame + (1 - mask_upscaled) * neg_frame
|
61 |
+
|
62 |
+
return mixed_frame, mask_upscaled
|
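As a quick illustration of how the three helpers above fit together, here is a minimal sketch (not part of the uploaded file; the tensor sizes are made up, and the helpers are assumed importable from this module):

```python
import torch
import torch.nn.functional as F

from denseav.data.AVDatasets import create_mixed_image, grid_frames, sample_choice

# Stand-ins for real video frames: batch of 1, RGB, 32x32.
pos_frame = torch.rand(1, 3, 32, 32)
neg_frame = torch.rand(1, 3, 32, 32)

# Patch mixup: every 8x8 block comes from either the positive or the negative frame;
# the returned mask is 1 wherever a pixel was taken from the positive frame.
mixed, mask = create_mixed_image(pos_frame, neg_frame, patch_size=8)
assert mixed.shape == (1, 3, 32, 32) and mask.shape == (1, 1, 32, 32)

# Quad mixup: four half-resolution frames tiled into a 2x2 grid of the original size.
halves = [F.interpolate(torch.rand(1, 3, 32, 32), scale_factor=0.5, mode="bilinear") for _ in range(4)]
assert grid_frames(halves).shape == (1, 3, 32, 32)

# One mixup strategy is drawn per example according to the configured probabilities.
print(sample_choice(["quad", "bg", "patch", None], [0.25, 0.25, 0.25, 0.25]))
```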
63 |
+
|
64 |
+
|
65 |
+
class AVDataset(ABC, Dataset):
|
66 |
+
|
67 |
+
@abstractmethod
|
68 |
+
def _dataset_folder(self) -> str:
|
69 |
+
pass
|
70 |
+
|
71 |
+
@abstractmethod
|
72 |
+
def _load_info(self, split) -> pd.DataFrame:
|
73 |
+
"""
|
74 |
+
This function should return a dataframe with at least a column "id"
|
75 |
+
@return:
|
76 |
+
"""
|
77 |
+
pass
|
78 |
+
|
79 |
+
@abstractmethod
|
80 |
+
def _missing_threshold(self) -> float:
|
81 |
+
pass
|
82 |
+
|
83 |
+
@abstractmethod
|
84 |
+
def default_target_length(self) -> int:
|
85 |
+
pass
|
86 |
+
|
87 |
+
def target_length(self):
|
88 |
+
if self.override_target_length is not None:
|
89 |
+
return self.override_target_length
|
90 |
+
else:
|
91 |
+
return self.default_target_length()
|
92 |
+
|
93 |
+
def _frame_root(self) -> str:
|
94 |
+
return join(self.root, "frames", self.split)
|
95 |
+
|
96 |
+
def _video_root(self) -> str:
|
97 |
+
return join(self.root, "videos", self.split)
|
98 |
+
|
99 |
+
def _audio_root(self) -> str:
|
100 |
+
return join(self.root, "audio", self.split)
|
101 |
+
|
102 |
+
def _semseg_root(self) -> str:
|
103 |
+
return join(self.root, "annotations", self.split)
|
104 |
+
|
105 |
+
def _embed_root(self) -> str:
|
106 |
+
return join(self.root, "embedding", self.audio_embed_model, self.split)
|
107 |
+
|
108 |
+
def _label_root(self) -> str:
|
109 |
+
return join(self.root, "pseudo-labels")
|
110 |
+
|
111 |
+
def _hn_root(self) -> str:
|
112 |
+
return join(self.root, "hard_negatives")
|
113 |
+
|
114 |
+
def _all_video_files(self) -> Set[str]:
|
115 |
+
return set(str(p) for p in Path(join(self._video_root())).rglob('*'))
|
116 |
+
|
117 |
+
def _all_frame_files(self) -> Set[str]:
|
118 |
+
return set(str(p) for p in Path(join(self._frame_root())).rglob('*'))
|
119 |
+
|
120 |
+
def _all_audio_files(self) -> Set[str]:
|
121 |
+
return set(str(p) for p in Path(join(self._audio_root())).rglob('*'))
|
122 |
+
|
123 |
+
def _all_embed_files(self) -> Set[str]:
|
124 |
+
return set(str(p) for p in Path(join(self._embed_root())).rglob('*'))
|
125 |
+
|
126 |
+
def _get_frame_files(self, row) -> List[str]:
|
127 |
+
return [self._frame_root() + "/" + row["id"] + f"_{i}.jpg" for i in range(self._expected_num_frames())]
|
128 |
+
|
129 |
+
def _get_semseg_file(self, row) -> str:
|
130 |
+
raise NotImplementedError("Class has not implemented _get_semseg_files")
|
131 |
+
|
132 |
+
def _get_audio_file(self, row) -> str:
|
133 |
+
return self._audio_root() + "/" + row["id"] + ".mp3"
|
134 |
+
|
135 |
+
def _get_video_file(self, row) -> str:
|
136 |
+
return self._video_root() + "/" + row["id"] + ".mp4"
|
137 |
+
|
138 |
+
def _get_embed_file(self, row) -> str:
|
139 |
+
return self._embed_root() + "/" + row["id"] + ".npz"
|
140 |
+
|
141 |
+
def _add_files_to_metadata(self, df) -> pd.DataFrame:
|
142 |
+
tqdm.pandas()
|
143 |
+
|
144 |
+
if self.use_audio_embed:
|
145 |
+
df["embed_file"] = df.progress_apply(self._get_embed_file, axis=1)
|
146 |
+
|
147 |
+
if self.use_audio or self.use_spec:
|
148 |
+
df["audio_file"] = df.progress_apply(self._get_audio_file, axis=1)
|
149 |
+
|
150 |
+
if self.use_frames:
|
151 |
+
df["frame_files"] = df.progress_apply(self._get_frame_files, axis=1)
|
152 |
+
|
153 |
+
if self.use_semseg:
|
154 |
+
df["semseg_file"] = df.progress_apply(self._get_semseg_file, axis=1)
|
155 |
+
|
156 |
+
df = self._filter_valid_metadata(df)
|
157 |
+
|
158 |
+
if self.use_hn:
|
159 |
+
loaded = np.load(join(self._hn_root(), "original", f"{self.split}_hard_negatives.npz"))
|
160 |
+
df["hn0"] = [t for t in torch.tensor(loaded["indices_0"])]
|
161 |
+
df["hn1"] = [t for t in torch.tensor(loaded["indices_1"])]
|
162 |
+
|
163 |
+
return df
|
164 |
+
|
165 |
+
def _split_name(self, split):
|
166 |
+
return split
|
167 |
+
|
168 |
+
def _filter_valid_metadata(self, df: pd.DataFrame) -> pd.DataFrame:
|
169 |
+
|
170 |
+
print("MY_DIR ", list(glob(join(self.root, "*"))))
|
171 |
+
if self.use_audio_embed:
|
172 |
+
missing_embed_files = set(df['embed_file']) - self.all_embed_files
|
173 |
+
valid_audio = ~df['embed_file'].isin(missing_embed_files)
|
174 |
+
print("ALL EMBED ", len(self.all_embed_files))
|
175 |
+
elif self.use_audio or self.use_spec:
|
176 |
+
missing_audio_files = set(df['audio_file']) - self.all_audio_files
|
177 |
+
valid_audio = ~df['audio_file'].isin(missing_audio_files)
|
178 |
+
print("ALL AUDIO ", len(self.all_audio_files))
|
179 |
+
|
180 |
+
if self.use_frames:
|
181 |
+
missing_frame_files = set(
|
182 |
+
item for sublist in df['frame_files'].tolist() for item in sublist) - self.all_frame_files
|
183 |
+
valid_frames = df['frame_files'].apply(lambda x: not any(file in missing_frame_files for file in x))
|
184 |
+
print("ALL FRAMES ", len(self.all_frame_files))
|
185 |
+
df["is_valid"] = valid_audio & valid_frames
|
186 |
+
else:
|
187 |
+
df["is_valid"] = valid_audio
|
188 |
+
|
189 |
+
percent_missing = (1 - (df["is_valid"].sum() / len(df)))
|
190 |
+
|
191 |
+
assert percent_missing <= self._missing_threshold(), \
|
192 |
+
        f"Too many missing files: {round(percent_missing * 100.0, 2)}%"
|
193 |
+
assert len(df) > 0, "No files found"
|
194 |
+
return df[df["is_valid"]]
|
195 |
+
|
196 |
+
def __init__(
|
197 |
+
self,
|
198 |
+
root: str,
|
199 |
+
split: str = "train",
|
200 |
+
use_frames=False,
|
201 |
+
frame_transform=None,
|
202 |
+
use_audio=False,
|
203 |
+
use_spec=False,
|
204 |
+
use_audio_embed=False,
|
205 |
+
use_hn=False,
|
206 |
+
use_caption=False,
|
207 |
+
use_semseg=False,
|
208 |
+
neg_audio=False,
|
209 |
+
use_davenet_spec=False,
|
210 |
+
use_fnac_spec=False,
|
211 |
+
n_label_frames=196,
|
212 |
+
label_transform=None,
|
213 |
+
audio_embed_model="hubert",
|
214 |
+
n_frames=1,
|
215 |
+
audio_transform=None,
|
216 |
+
audio_aug=False,
|
217 |
+
spec_transform=None,
|
218 |
+
spec_mel_bins=128,
|
219 |
+
spec_mean=-6.6268077,
|
220 |
+
spec_std=5.358466,
|
221 |
+
sample_rate=16000,
|
222 |
+
override_target_length=None,
|
223 |
+
use_tags=False,
|
224 |
+
extra_audio_masking=False,
|
225 |
+
audio_level=False,
|
226 |
+
quad_mixup=0.0,
|
227 |
+
bg_mixup=0.0,
|
228 |
+
patch_mixup=0.0,
|
229 |
+
patch_size=8,
|
230 |
+
):
|
231 |
+
        super().__init__()
|
232 |
+
self.pytorch_data_dir = root
|
233 |
+
self.split = self._split_name(split)
|
234 |
+
self.root = join(root, self._dataset_folder())
|
235 |
+
self.use_frames = use_frames
|
236 |
+
self.frame_transform = frame_transform
|
237 |
+
self.use_audio = use_audio
|
238 |
+
self.use_spec = use_spec
|
239 |
+
self.use_audio_embed = use_audio_embed
|
240 |
+
self.use_davenet_spec = use_davenet_spec
|
241 |
+
self.use_fnac_spec = use_fnac_spec
|
242 |
+
self.use_hn = use_hn
|
243 |
+
self.use_caption = use_caption
|
244 |
+
self.label_transform = label_transform
|
245 |
+
self.audio_embed_model = audio_embed_model
|
246 |
+
self.audio_aug = audio_aug
|
247 |
+
self.n_frames = n_frames
|
248 |
+
self.audio_transform = audio_transform
|
249 |
+
self.spec_transform = spec_transform
|
250 |
+
self.spec_mel_bins = spec_mel_bins
|
251 |
+
self.spec_mean = spec_mean
|
252 |
+
self.spec_std = spec_std
|
253 |
+
self.use_semseg = use_semseg
|
254 |
+
self.override_target_length = override_target_length
|
255 |
+
self.use_tags = use_tags
|
256 |
+
self.extra_audio_masking = extra_audio_masking
|
257 |
+
self.neg_audio = neg_audio
|
258 |
+
self.audio_level = audio_level
|
259 |
+
|
260 |
+
self.quad_mixup = quad_mixup
|
261 |
+
self.bg_mixup = bg_mixup
|
262 |
+
self.patch_mixup = patch_mixup
|
263 |
+
self.patch_size = patch_size
|
264 |
+
|
265 |
+
self.sample_rate = sample_rate
|
266 |
+
self.n_label_frames = n_label_frames
|
267 |
+
|
268 |
+
if self.use_audio_embed:
|
269 |
+
self.all_embed_files = self._all_embed_files()
|
270 |
+
|
271 |
+
if self.use_audio or self.use_spec:
|
272 |
+
self.all_audio_files = self._all_audio_files()
|
273 |
+
|
274 |
+
if self.use_frames:
|
275 |
+
self.all_frame_files = self._all_frame_files()
|
276 |
+
|
277 |
+
self.metadata = self._add_files_to_metadata(self._load_info(self.split))
|
278 |
+
|
279 |
+
assert len(self.metadata) > 0
|
280 |
+
|
281 |
+
def __len__(self):
|
282 |
+
return len(self.metadata)
|
283 |
+
|
284 |
+
@abstractmethod
|
285 |
+
def _expected_num_frames(self) -> int:
|
286 |
+
pass
|
287 |
+
|
288 |
+
def get_audio_mask(self, real_length, padded_length, target_size):
|
289 |
+
if not isinstance(real_length, torch.Tensor):
|
290 |
+
real_length = torch.tensor(real_length)
|
291 |
+
padded_length = torch.tensor(padded_length)
|
292 |
+
|
293 |
+
n_frames = ((real_length / padded_length) * target_size).to(torch.int64)
|
294 |
+
oh = F.one_hot(n_frames, num_classes=target_size + 1)
|
295 |
+
if len(oh.shape) == 1:
|
296 |
+
oh = oh.unsqueeze(0)
|
297 |
+
return (1 - torch.cumsum(oh, dim=1))[:, :-1].to(torch.bool)
|
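To make the mask construction above concrete, a small worked example (the numbers are illustrative and `dataset` stands for any instance of an `AVDataset` subclass):

```python
# 4 s of real audio padded to 10 s, with 10 audio token positions:
# n_frames = int((4 / 10) * 10) = 4, so exactly the first 4 positions are marked valid.
mask = dataset.get_audio_mask(real_length=4.0, padded_length=10.0, target_size=10)
print(mask)  # tensor([[ True,  True,  True,  True, False, False, False, False, False, False]])
```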
298 |
+
|
299 |
+
def _base_get_item(self, item):
|
300 |
+
id = self.metadata["id"].iloc[item]
|
301 |
+
data_dict = {"metadata": {"id": id, "index": item}}
|
302 |
+
|
303 |
+
if self.use_tags and "tags" in self.metadata:
|
304 |
+
tags = torch.tensor(self.metadata["tags"].iloc[item])
|
305 |
+
tag_oh = torch.zeros(self.num_tags, dtype=torch.float32)
|
306 |
+
tag_oh[tags] += 1
|
307 |
+
data_dict["tags"] = tag_oh
|
308 |
+
|
309 |
+
if self.use_audio or self.use_spec:
|
310 |
+
audio_file = self.metadata["audio_file"].iloc[item]
|
311 |
+
data_dict["metadata"]["audio_file"] = audio_file
|
312 |
+
loaded_waveform, obs_sr = torchaudio.load(audio_file)
|
313 |
+
loaded_waveform = loaded_waveform[0]
|
314 |
+
|
315 |
+
if self.neg_audio:
|
316 |
+
neg_audio_file = self.metadata["audio_file"].iloc[torch.randint(0, len(self), size=(1,)).item()]
|
317 |
+
data_dict["metadata"]["neg_audio_file"] = neg_audio_file
|
318 |
+
neg_waveform, neg_obs_sr = torchaudio.load(neg_audio_file)
|
319 |
+
neg_waveform = neg_waveform[0]
|
320 |
+
else:
|
321 |
+
neg_waveform, neg_obs_sr = None, None
|
322 |
+
|
323 |
+
(waveform,
|
324 |
+
spectrogram,
|
325 |
+
audio_length,
|
326 |
+
total_length,
|
327 |
+
original_length,
|
328 |
+
mask,
|
329 |
+
pos_mask) = prep_waveform(
|
330 |
+
loaded_waveform,
|
331 |
+
obs_sr,
|
332 |
+
self.target_length(),
|
333 |
+
self.spec_mel_bins,
|
334 |
+
self.spec_mean,
|
335 |
+
self.spec_std,
|
336 |
+
self.sample_rate,
|
337 |
+
self.use_spec,
|
338 |
+
False,
|
339 |
+
self.extra_audio_masking,
|
340 |
+
neg_waveform,
|
341 |
+
neg_obs_sr,
|
342 |
+
self.audio_level,
|
343 |
+
self.audio_aug
|
344 |
+
)
|
345 |
+
|
346 |
+
if self.spec_transform is not None and spectrogram is not None:
|
347 |
+
spectrogram = self.spec_transform(spectrogram)
|
348 |
+
|
349 |
+
if self.audio_transform is not None:
|
350 |
+
waveform = self.audio_transform(waveform)
|
351 |
+
|
352 |
+
data_dict["audio"] = waveform
|
353 |
+
data_dict[AUDIO_MASK] = mask
|
354 |
+
data_dict[AUDIO_POS_MASK] = pos_mask
|
355 |
+
data_dict["audio_length"] = audio_length
|
356 |
+
data_dict["original_length"] = original_length
|
357 |
+
data_dict["total_length"] = total_length
|
358 |
+
if spectrogram is not None:
|
359 |
+
data_dict["spec"] = spectrogram
|
360 |
+
|
361 |
+
if mask.mean() < .04:
|
362 |
+
return None
|
363 |
+
|
364 |
+
if self.use_davenet_spec:
|
365 |
+
from data.DavenetUtilities import davenet_load_audio
|
366 |
+
audio_file = self.metadata["audio_file"].iloc[item]
|
367 |
+
spec, n_frames = davenet_load_audio(audio_file)
|
368 |
+
data_dict["davenet_spec"] = spec
|
369 |
+
|
370 |
+
if self.use_fnac_spec:
|
371 |
+
from featurizers.FNACAVL import load_spectrogram as fnac_load_spectrogram
|
372 |
+
audio_file = self.metadata["audio_file"].iloc[item]
|
373 |
+
data_dict["fnac_spec"] = fnac_load_spectrogram(audio_file, 3)
|
374 |
+
|
375 |
+
if self.use_audio_embed:
|
376 |
+
loaded = np.load(self.metadata["embed_file"].iloc[item])
|
377 |
+
data_dict["audio_emb"] = loaded["feat"]
|
378 |
+
data_dict["audio_length"] = loaded["audio_length"]
|
379 |
+
data_dict["total_length"] = loaded["total_length"]
|
380 |
+
data_dict["original_length"] = loaded["original_length"]
|
381 |
+
data_dict[AUDIO_MASK] = self.get_audio_mask(
|
382 |
+
data_dict["audio_length"],
|
383 |
+
data_dict["total_length"],
|
384 |
+
data_dict["audio_emb"].shape[-1]) \
|
385 |
+
.squeeze().to(torch.float32)
|
386 |
+
data_dict[AUDIO_POS_MASK] = data_dict[AUDIO_MASK].to(torch.float32)
|
387 |
+
|
388 |
+
if self.use_frames:
|
389 |
+
|
390 |
+
def get_frames(item):
|
391 |
+
file_group = self.metadata["frame_files"].iloc[item]
|
392 |
+
if self.n_frames is not None:
|
393 |
+
selected_frames = torch.randperm(len(file_group))[:self.n_frames]
|
394 |
+
file_group = [file_group[i] for i in selected_frames]
|
395 |
+
data_dict["metadata"]["frame_files"] = file_group
|
396 |
+
images = [Image.open(file).convert("RGB") for file in file_group]
|
397 |
+
|
398 |
+
if self.frame_transform is not None:
|
399 |
+
images = torch.cat([self.frame_transform(img).unsqueeze(0) for img in images], dim=0)
|
400 |
+
|
401 |
+
return images, file_group
|
402 |
+
|
403 |
+
no_mixup = 1.0 - (self.bg_mixup + self.quad_mixup + self.patch_mixup)
|
404 |
+
|
405 |
+
mixup_type = sample_choice(
|
406 |
+
["quad", "bg", "patch", None],
|
407 |
+
[self.quad_mixup, self.bg_mixup, self.patch_mixup, no_mixup]
|
408 |
+
)
|
409 |
+
|
410 |
+
if mixup_type == "quad":
|
411 |
+
indices = [item] + torch.randint(0, len(self), size=(3,)).numpy().tolist()
|
412 |
+
frames_and_files = [get_frames(i) for i in indices]
|
413 |
+
file_group = frames_and_files[0][1]
|
414 |
+
perm = torch.randperm(4)
|
415 |
+
all_frames = [F.interpolate(frames_and_files[i][0], scale_factor=0.5, mode="bilinear") for i in
|
416 |
+
perm]
|
417 |
+
b, c, h, w = all_frames[0].shape
|
418 |
+
indices = [indices[p] for p in perm]
|
419 |
+
masks = [(torch.ones(b, 1, h, w) if index == item else torch.zeros(b, 1, h, w)) for index in
|
420 |
+
indices]
|
421 |
+
|
422 |
+
data_dict[IMAGE_INPUT] = grid_frames(all_frames)
|
423 |
+
data_dict[IMAGE_MASK] = grid_frames(masks)
|
424 |
+
elif mixup_type == "bg":
|
425 |
+
neg_item = torch.randint(0, len(self), size=(1,)).item()
|
426 |
+
neg_frame, _ = get_frames(neg_item)
|
427 |
+
pos_frame, file_group = get_frames(item)
|
428 |
+
|
429 |
+
b, c, h, w = neg_frame.shape
|
430 |
+
neg_mask = torch.zeros(b, 1, h, w)
|
431 |
+
pos_mask = torch.ones(b, 1, h, w)
|
432 |
+
|
433 |
+
if torch.rand(1).item() > 0.5:
|
434 |
+
bg_frame = neg_frame
|
435 |
+
bg_mask = neg_mask
|
436 |
+
fg_frame = F.interpolate(pos_frame, scale_factor=0.5, mode="bilinear")
|
437 |
+
fg_mask = F.interpolate(pos_mask, scale_factor=0.5, mode="bilinear")
|
438 |
+
else:
|
439 |
+
bg_frame = pos_frame
|
440 |
+
bg_mask = pos_mask
|
441 |
+
fg_frame = F.interpolate(neg_frame, scale_factor=0.5, mode="bilinear")
|
442 |
+
fg_mask = F.interpolate(neg_mask, scale_factor=0.5, mode="bilinear")
|
443 |
+
|
444 |
+
start_h = torch.randint(0, h // 2, size=(1,))
|
445 |
+
start_w = torch.randint(0, w // 2, size=(1,))
|
446 |
+
bg_frame[:, :, start_h:start_h + fg_frame.shape[2], start_w:start_w + fg_frame.shape[3]] = fg_frame
|
447 |
+
bg_mask[:, :, start_h:start_h + fg_frame.shape[2], start_w:start_w + fg_frame.shape[3]] = fg_mask
|
448 |
+
|
449 |
+
data_dict["frames"] = bg_frame
|
450 |
+
data_dict["image_masks"] = bg_mask
|
451 |
+
|
452 |
+
elif mixup_type == "patch":
|
453 |
+
neg_item = torch.randint(0, len(self), size=(1,)).item()
|
454 |
+
neg_frame, _ = get_frames(neg_item)
|
455 |
+
pos_frame, file_group = get_frames(item)
|
456 |
+
frames, masks = create_mixed_image(pos_frame, neg_frame, self.patch_size)
|
457 |
+
data_dict["frames"] = frames
|
458 |
+
data_dict["image_masks"] = masks
|
459 |
+
|
460 |
+
elif mixup_type is None:
|
461 |
+
frames, file_group = get_frames(item)
|
462 |
+
|
463 |
+
data_dict["frames"] = frames
|
464 |
+
b, c, h, w = frames.shape
|
465 |
+
data_dict["image_masks"] = torch.ones(b, 1, h, w)
|
466 |
+
else:
|
467 |
+
raise ValueError(f"Unknown mixup type {mixup_type}")
|
468 |
+
|
469 |
+
if "original_length" in data_dict:
|
470 |
+
if self._expected_num_frames() == 1:
|
471 |
+
frame_nums = torch.tensor([0])
|
472 |
+
else:
|
473 |
+
frame_nums = torch.tensor([
|
474 |
+
int(f.split("/")[-1].split("_")[-1].split(".")[0]) for f in file_group])
|
475 |
+
|
476 |
+
data_dict["frame_nums"] = frame_nums
|
477 |
+
frame_fracs = ((frame_nums + .5) / (self._expected_num_frames()))
|
478 |
+
frame_position = (frame_fracs * data_dict["original_length"]) / data_dict["total_length"]
|
479 |
+
data_dict["frame_position"] = frame_position
|
480 |
+
|
481 |
+
if self.use_caption:
|
482 |
+
if "word" in self.metadata:
|
483 |
+
words = self.metadata["word"].iloc[item]
|
484 |
+
start = self.metadata["start"].iloc[item]
|
485 |
+
end = self.metadata["end"].iloc[item]
|
486 |
+
if isinstance(words, float):
|
487 |
+
words = [""]
|
488 |
+
start = [0.0]
|
489 |
+
end = [-1.0]
|
490 |
+
|
491 |
+
data_dict["caption"] = {
|
492 |
+
"words": words,
|
493 |
+
"start": start,
|
494 |
+
"end": end,
|
495 |
+
}
|
496 |
+
if "text" in self.metadata:
|
497 |
+
data_dict["text"] = self.metadata["text"].iloc[item]
|
498 |
+
|
499 |
+
if self.use_semseg:
|
500 |
+
semseg_path = join(self._semseg_root(), self.metadata["semseg_file"].iloc[item])
|
501 |
+
semseg = Image.open(semseg_path)
|
502 |
+
if self.label_transform is not None:
|
503 |
+
semseg = np.array(self.label_transform(semseg))
|
504 |
+
data_dict["semseg"] = semseg
|
505 |
+
data_dict["metadata"]["semseg_file"] = semseg_path
|
506 |
+
|
507 |
+
# if hasattr(self, "num_classes"):
|
508 |
+
# data_dict["num_pixels_per_class"] = F.one_hot(
|
509 |
+
# torch.tensor(semseg).to(torch.int64), self.num_classes() + 1).sum(dim=[0, 1])
|
510 |
+
|
511 |
+
return data_dict
|
512 |
+
|
513 |
+
def __getitem__(self, item):
|
514 |
+
try:
|
515 |
+
data_dict = self._base_get_item(item)
|
516 |
+
if self.use_hn:
|
517 |
+
indices = torch.cat([self.metadata["hn0"].iloc[item], self.metadata["hn1"].iloc[item]], dim=0)
|
518 |
+
neg_index = indices[torch.randint(0, indices.shape[0], (1,))]
|
519 |
+
negative_dict = self._base_get_item(neg_index)
|
520 |
+
data_dict["negatives"] = negative_dict
|
521 |
+
return data_dict
|
522 |
+
except (audioread.exceptions.NoBackendError, EOFError) as e:
|
523 |
+
# raise e
|
524 |
+
bad_path = self.metadata["audio_file"].iloc[item]
|
525 |
+
print(e)
|
526 |
+
print(f"Removing bad audio file {bad_path}")
|
527 |
+
# os.remove(bad_path)
|
528 |
+
return None
|
529 |
+
except ValueError as e:
|
530 |
+
# raise e
|
531 |
+
bad_path = self.metadata["audio_file"].iloc[item]
|
532 |
+
if "Input signal length=0" in str(e):
|
533 |
+
print(e)
|
534 |
+
print(f"Removing bad file {bad_path} due to input signal length=0")
|
535 |
+
# os.remove(bad_path)
|
536 |
+
return None
|
537 |
+
except OSError as e:
|
538 |
+
# raise e
|
539 |
+
bad_paths = self.metadata["frame_files"].iloc[item]
|
540 |
+
for bad_path in bad_paths:
|
541 |
+
print(e)
|
542 |
+
print(f"Removing bad frame file {bad_path}")
|
543 |
+
return None
|
544 |
+
except RuntimeError as e:
|
545 |
+
# raise e
|
546 |
+
bad_path = self.metadata["audio_file"].iloc[item]
|
547 |
+
print(e)
|
548 |
+
print(f"Removing bad audio file {bad_path}")
|
549 |
+
# os.remove(bad_path)
|
550 |
+
return None
|
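The concrete datasets below all follow the same recipe; for orientation, a minimal hypothetical subclass (the folder and CSV names here are invented, not part of this repository) only has to describe where its files live and how its metadata table is built:

```python
class MyToyDataset(AVDataset):
    # Hypothetical example subclass, shown only to illustrate the abstract interface.

    def _dataset_folder(self) -> str:
        return "my-toy-dataset"           # subfolder of the pytorch data root

    def _load_info(self, split) -> pd.DataFrame:
        # Must return a dataframe with at least an "id" column.
        return pd.read_csv(join(self.root, "metadata", f"{split}.csv"))

    def _missing_threshold(self) -> float:
        return 0.05                        # tolerate up to 5% missing files

    def default_target_length(self) -> int:
        return 10                          # seconds of audio per clip

    def _expected_num_frames(self) -> int:
        return 1                           # one extracted frame per clip
```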
551 |
+
|
552 |
+
|
553 |
+
class PlacesAudio(AVDataset):
|
554 |
+
|
555 |
+
def _load_info(self, split) -> pd.DataFrame:
|
556 |
+
df = pd.read_json(join(os.path.dirname(self._audio_root()), "metadata", f"{split}.json"))
|
557 |
+
df["id"] = df["data"].apply(lambda d: d["wav"][5:-4])
|
558 |
+
|
559 |
+
if self.use_caption:
|
560 |
+
if split == "train":
|
561 |
+
word_df = pd.read_json(
|
562 |
+
join(os.path.dirname(self._audio_root()), "metadata", f"word-alignment-{split}.json")
|
563 |
+
)
|
564 |
+
else:
|
565 |
+
word_df = pd.read_csv(
|
566 |
+
join(os.path.dirname(self._audio_root()), "metadata", f"word-alignment-{split}.csv")) \
|
567 |
+
.groupby("id").aggregate(lambda g: list(g)).reset_index().drop("Unnamed: 0", axis=1)
|
568 |
+
df = pd.merge(df, word_df, on="id", how="outer")
|
569 |
+
return df
|
570 |
+
|
571 |
+
def _missing_threshold(self) -> float:
|
572 |
+
# return 0.0
|
573 |
+
return 0.97 # TODO fix
|
574 |
+
|
575 |
+
def _expected_num_frames(self):
|
576 |
+
return 1
|
577 |
+
|
578 |
+
def default_target_length(self) -> int:
|
579 |
+
return 20
|
580 |
+
|
581 |
+
def _frame_root(self) -> str:
|
582 |
+
return join(os.path.dirname(self.root), "places_subset")
|
583 |
+
|
584 |
+
def _audio_root(self) -> str:
|
585 |
+
return join(self.root, "wavs")
|
586 |
+
|
587 |
+
def _embed_root(self) -> str:
|
588 |
+
return join(self.root, "embedding", self.audio_embed_model)
|
589 |
+
|
590 |
+
def _dataset_folder(self) -> str:
|
591 |
+
return "PlacesAudio_400k_distro"
|
592 |
+
|
593 |
+
def _get_audio_file(self, row) -> str:
|
594 |
+
return join(self._audio_root(), row["id"] + ".wav")
|
595 |
+
|
596 |
+
def _get_frame_files(self, row) -> List[str]:
|
597 |
+
return [join(self._frame_root(), row["data"]["image"])]
|
598 |
+
|
599 |
+
def _get_embed_file(self, row) -> str:
|
600 |
+
return join(self._embed_root(), row["id"] + ".npz")
|
601 |
+
|
602 |
+
|
603 |
+
class AudioSet(AVDataset):
|
604 |
+
def _expected_num_frames(self):
|
605 |
+
return 10
|
606 |
+
|
607 |
+
def default_target_length(self) -> int:
|
608 |
+
return 20
|
609 |
+
|
610 |
+
def _dataset_folder(self) -> str:
|
611 |
+
return "audioset-raw"
|
612 |
+
|
613 |
+
def _missing_threshold(self) -> float:
|
614 |
+
if self.split == "val" or self.split == "test":
|
615 |
+
return 0.02
|
616 |
+
else:
|
617 |
+
return 0.17
|
618 |
+
|
619 |
+
def train_seg_file(self):
|
620 |
+
return "unbalanced_train_segments.csv"
|
621 |
+
|
622 |
+
def _load_info(self, split) -> pd.DataFrame:
|
623 |
+
if split == "train":
|
624 |
+
df = pd.read_csv(join(self.root, "metadata", self.train_seg_file()))
|
625 |
+
elif split == "val" or split == "test":
|
626 |
+
df = pd.read_csv(join(self.root, "metadata", "eval_segments_subset.csv"))
|
627 |
+
else:
|
628 |
+
raise ValueError(f"Unknown split {split}")
|
629 |
+
|
630 |
+
labels = pd.read_csv(join(self.root, "metadata", "class_labels_indices.csv"))
|
631 |
+
mid_to_index = dict(zip(labels["mid"], labels["index"]))
|
632 |
+
df["tags"] = df["positive_labels"].apply(lambda l: [mid_to_index[e] for e in l.strip('"').split(",")])
|
633 |
+
|
634 |
+
self.num_tags = max(*[i for k, i in mid_to_index.items()]) + 1
|
635 |
+
df["id"] = df.apply(lambda r: f"{r.YTID}_{r.start_seconds}_{r.end_seconds}", axis=1)
|
636 |
+
return df
|
637 |
+
|
638 |
+
def _frame_root(self) -> str:
|
639 |
+
return join(self.root, "frames")
|
640 |
+
|
641 |
+
def _audio_root(self) -> str:
|
642 |
+
return join(self.root, "audio")
|
643 |
+
|
644 |
+
def _all_frame_files(self) -> Set[str]:
|
645 |
+
frame_files = set()
|
646 |
+
|
647 |
+
for entry in os.scandir(self._frame_root()):
|
648 |
+
if entry.is_file():
|
649 |
+
frame_files.add(entry.path)
|
650 |
+
elif entry.is_dir():
|
651 |
+
for subentry in os.scandir(entry.path):
|
652 |
+
if subentry.is_file():
|
653 |
+
frame_files.add(subentry.path)
|
654 |
+
|
655 |
+
return frame_files
|
656 |
+
|
657 |
+
def _all_audio_files(self) -> Set[str]:
|
658 |
+
return set(entry.path for entry in os.scandir(self._audio_root()) if entry.is_file())
|
659 |
+
|
660 |
+
def _all_embed_files(self) -> Set[str]:
|
661 |
+
return set(entry.path for entry in os.scandir(self._embed_root()) if entry.is_file())
|
662 |
+
|
663 |
+
def _embed_root(self) -> str:
|
664 |
+
return join(self.root, "embedding", self.audio_embed_model)
|
665 |
+
|
666 |
+
def prefix(self):
|
667 |
+
return ""
|
668 |
+
|
669 |
+
def _get_audio_file(self, row) -> str:
|
670 |
+
return f"{self.root}/audio/{self.prefix()}{row.id}.mp3"
|
671 |
+
|
672 |
+
def _get_frame_files(self, row) -> List[str]:
|
673 |
+
return [f"{self.root}/frames/frame_{fn}/{self.prefix()}{row.id}.jpg" for fn in range(10)]
|
674 |
+
|
675 |
+
def _get_embed_file(self, row) -> str:
|
676 |
+
return f"{self.root}/embedding/{self.audio_embed_model}/{self.prefix()}{row.id}.npz"
|
677 |
+
|
678 |
+
|
679 |
+
class AudioSetEval(AudioSet):
|
680 |
+
|
681 |
+
def _dataset_folder(self) -> str:
|
682 |
+
return "audioset-eval"
|
683 |
+
|
684 |
+
def _get_frame_files(self, row) -> List[str]:
|
685 |
+
base_path = f"{self.root}/frames/{self.prefix()}{row.id}_"
|
686 |
+
return [base_path + f"{fn}.jpg" for fn in range(10)]
|
687 |
+
|
688 |
+
def prefix(self):
|
689 |
+
return ""
|
690 |
+
|
691 |
+
|
692 |
+
class ADE20K(AVDataset):
|
693 |
+
|
694 |
+
def _split_name(self, split):
|
695 |
+
if split == "val":
|
696 |
+
return "validation"
|
697 |
+
elif split == "train":
|
698 |
+
return "training"
|
699 |
+
else:
|
700 |
+
raise ValueError(f"Unknown split name {split}")
|
701 |
+
|
702 |
+
def _load_info(self, split) -> pd.DataFrame:
|
703 |
+
df = pd.read_json(join(self.root, "metadata_with_caption_dedup.json"))
|
704 |
+
df["id"] = df["image"]
|
705 |
+
df = df[df["image"].apply(lambda f: f.split("/")[0] == split)]
|
706 |
+
|
707 |
+
if self.use_caption:
|
708 |
+
df["word"] = df["caption"].apply(lambda c: c["words"])
|
709 |
+
df["start"] = df["caption"].apply(lambda c: c["start"])
|
710 |
+
df["end"] = df["caption"].apply(lambda c: c["end"])
|
711 |
+
df["text"] = df["word"].apply(lambda l: " ".join(l))
|
712 |
+
return df
|
713 |
+
|
714 |
+
def _missing_threshold(self) -> float:
|
715 |
+
return 0.03
|
716 |
+
|
717 |
+
def _expected_num_frames(self):
|
718 |
+
return 1
|
719 |
+
|
720 |
+
def default_target_length(self) -> int:
|
721 |
+
return 20
|
722 |
+
|
723 |
+
def _dataset_folder(self) -> str:
|
724 |
+
return "ADE20K"
|
725 |
+
|
726 |
+
def _frame_root(self) -> str:
|
727 |
+
return join(self.root, "frames")
|
728 |
+
|
729 |
+
def _audio_root(self) -> str:
|
730 |
+
return join(self.root, "audio")
|
731 |
+
|
732 |
+
def _semseg_root(self) -> str:
|
733 |
+
return join(self.root, "annotations")
|
734 |
+
|
735 |
+
def _embed_root(self) -> str:
|
736 |
+
return join(self.root, "embedding", self.audio_embed_model)
|
737 |
+
|
738 |
+
def _get_audio_file(self, row) -> str:
|
739 |
+
return join(self._audio_root(), row["audio"])
|
740 |
+
|
741 |
+
def _get_frame_files(self, row) -> List[str]:
|
742 |
+
return [join(self._frame_root(), row["image"])]
|
743 |
+
|
744 |
+
def _get_semseg_file(self, row) -> str:
|
745 |
+
return join(self._semseg_root(), row["seg"])
|
746 |
+
|
747 |
+
def _get_embed_file(self, row) -> str:
|
748 |
+
return join(self._embed_root(), row["image"].replace(".jpg", ".npz"))
|
749 |
+
|
750 |
+
def num_classes(self):
|
751 |
+
return 3662
|
752 |
+
|
753 |
+
|
754 |
+
class ADE20KPromptedBase(AVDataset):
|
755 |
+
|
756 |
+
def _expected_num_frames(self):
|
757 |
+
return 1
|
758 |
+
|
759 |
+
def default_target_length(self) -> int:
|
760 |
+
return 20
|
761 |
+
|
762 |
+
def _frame_root(self) -> str:
|
763 |
+
return join(self.root, "frames")
|
764 |
+
|
765 |
+
def _audio_root(self) -> str:
|
766 |
+
return join(self.root, "audio")
|
767 |
+
|
768 |
+
def _semseg_root(self) -> str:
|
769 |
+
return join(self.root, "annotations")
|
770 |
+
|
771 |
+
def _embed_root(self) -> str:
|
772 |
+
return join(self.root, "embedding", self.audio_embed_model)
|
773 |
+
|
774 |
+
def _get_frame_files(self, row) -> List[str]:
|
775 |
+
return [join(self._frame_root(), row["image_location"])]
|
776 |
+
|
777 |
+
def _get_semseg_file(self, row) -> str:
|
778 |
+
return join(self._semseg_root(), row["image_location"].replace(".jpg", "_seg.png"))
|
779 |
+
|
780 |
+
def _get_embed_file(self, row) -> str:
|
781 |
+
return join(self._embed_root(), row["image_location"].replace(".jpg", ".npz"))
|
782 |
+
|
783 |
+
def num_classes(self):
|
784 |
+
return 3662
|
785 |
+
|
786 |
+
def _missing_threshold(self) -> float:
|
787 |
+
return 0.0
|
788 |
+
|
789 |
+
|
790 |
+
class ADE20KSpeechPrompted(ADE20KPromptedBase):
|
791 |
+
|
792 |
+
def _get_audio_file(self, row) -> str:
|
793 |
+
return join(self._audio_root(), row["speech_prompt_file"].split("/")[-1])
|
794 |
+
|
795 |
+
def _dataset_folder(self) -> str:
|
796 |
+
return "ADE20KSpeechPrompted"
|
797 |
+
|
798 |
+
def _audio_root(self) -> str:
|
799 |
+
# return join(self.root, "audio-noise-10") # TODO Remove
|
800 |
+
return join(self.root, "audio") # TODO Remove
|
801 |
+
|
802 |
+
def _load_info(self, split) -> pd.DataFrame:
|
803 |
+
df = pd.read_csv(join(self.root, "prompted_segmentation.csv"))
|
804 |
+
df = df[df["speech_prompt_file"].apply(lambda s: isinstance(s, str))]
|
805 |
+
df = df[df["ade_class_id"].apply(lambda id: id != 0)]
|
806 |
+
df["id"] = df["image_location"]
|
807 |
+
return df
|
808 |
+
|
809 |
+
|
810 |
+
class ADE20KSoundPrompted(ADE20KPromptedBase):
|
811 |
+
|
812 |
+
def _get_audio_file(self, row) -> str:
|
813 |
+
return join(self._audio_root(), row["vggsound_file"].split("/")[-1])
|
814 |
+
|
815 |
+
def _dataset_folder(self) -> str:
|
816 |
+
return "ADE20KSoundPrompted"
|
817 |
+
|
818 |
+
def _load_info(self, split) -> pd.DataFrame:
|
819 |
+
df = pd.read_csv(join(self.root, "prompted_segmentation.csv"))
|
820 |
+
df = df[df["vggsound_file"].apply(lambda s: isinstance(s, str))]
|
821 |
+
df = df[df["ade_class_id"].apply(lambda id: id != 0)]
|
822 |
+
df["id"] = df["image_location"]
|
823 |
+
return df
|
824 |
+
|
825 |
+
|
826 |
+
class PlacesAndAudioSet(Dataset):
|
827 |
+
|
828 |
+
def __init__(self, **kwargs):
|
829 |
+
self.ds1 = PlacesAudio(**kwargs, n_frames=1)
|
830 |
+
self.ds2 = AudioSet(**kwargs, n_frames=1)
|
831 |
+
|
832 |
+
def __len__(self):
|
833 |
+
return len(self.ds1)
|
834 |
+
|
835 |
+
def __getitem__(self, item):
|
836 |
+
if torch.rand(1).item() > .5:
|
837 |
+
            d = self.ds2[torch.randint(0, len(self.ds2), size=(1,)).item()]
|
838 |
+
if d is not None:
|
839 |
+
d["source"] = 1
|
840 |
+
else:
|
841 |
+
d = self.ds1[item]
|
842 |
+
if d is not None:
|
843 |
+
d["source"] = 0
|
844 |
+
return d
|
845 |
+
|
846 |
+
|
847 |
+
class AVDataModule(pl.LightningDataModule):
|
848 |
+
def __init__(self,
|
849 |
+
dataset_name,
|
850 |
+
load_size,
|
851 |
+
image_aug,
|
852 |
+
audio_aug,
|
853 |
+
extra_audio_masking,
|
854 |
+
audio_model_type,
|
855 |
+
pytorch_data_dir,
|
856 |
+
use_cached_embs,
|
857 |
+
batch_size,
|
858 |
+
num_workers,
|
859 |
+
audio_level,
|
860 |
+
neg_audio,
|
861 |
+
data_for_plotting,
|
862 |
+
use_original_val_set,
|
863 |
+
use_extra_val_sets,
|
864 |
+
quad_mixup,
|
865 |
+
bg_mixup,
|
866 |
+
patch_mixup,
|
867 |
+
patch_size,
|
868 |
+
**kwargs):
|
869 |
+
|
870 |
+
super().__init__()
|
871 |
+
self.dataset_name = dataset_name
|
872 |
+
self.load_size = load_size
|
873 |
+
self.image_aug = image_aug
|
874 |
+
self.audio_aug = audio_aug
|
875 |
+
self.extra_audio_masking = extra_audio_masking
|
876 |
+
self.audio_model_type = audio_model_type
|
877 |
+
self.pytorch_data_dir = pytorch_data_dir
|
878 |
+
self.use_cached_embs = use_cached_embs
|
879 |
+
self.batch_size = batch_size
|
880 |
+
self.num_workers = num_workers
|
881 |
+
self.data_for_plotting = data_for_plotting
|
882 |
+
self.audio_level = audio_level
|
883 |
+
self.neg_audio = neg_audio
|
884 |
+
|
885 |
+
self.quad_mixup = quad_mixup
|
886 |
+
self.bg_mixup = bg_mixup
|
887 |
+
self.patch_mixup = patch_mixup
|
888 |
+
self.patch_size = patch_size
|
889 |
+
|
890 |
+
self.loader_args = dict(
|
891 |
+
num_workers=self.num_workers,
|
892 |
+
batch_size=self.batch_size,
|
893 |
+
)
|
894 |
+
self.save_hyperparameters()
|
895 |
+
self.extra_args = kwargs
|
896 |
+
|
897 |
+
self.use_original_val_set = use_original_val_set
|
898 |
+
self.use_extra_val_sets = use_extra_val_sets
|
899 |
+
|
900 |
+
def maybe_unpack(self, remove_source):
|
901 |
+
targets = [
|
902 |
+
(
|
903 |
+
join(self.pytorch_data_dir, "audioset-subset", "frame_archives"),
|
904 |
+
join(self.pytorch_data_dir, "audioset-subset", "frames"),
|
905 |
+
1
|
906 |
+
),
|
907 |
+
(
|
908 |
+
join(self.pytorch_data_dir, "audioset-raw", "frame_archives"),
|
909 |
+
join(self.pytorch_data_dir, "audioset-raw", "frames"),
|
910 |
+
4
|
911 |
+
),
|
912 |
+
(
|
913 |
+
join(self.pytorch_data_dir, "audioset-raw", "audio_archives"),
|
914 |
+
join(self.pytorch_data_dir, "audioset-raw", "audio"),
|
915 |
+
1
|
916 |
+
),
|
917 |
+
|
918 |
+
]
|
919 |
+
|
920 |
+
for (archive_dir, target_dir, n_parts) in targets:
|
921 |
+
if not os.path.exists(target_dir) and os.path.exists(archive_dir):
|
922 |
+
print(f"Could not find {target_dir}, attempting to unpack archives")
|
923 |
+
if os.path.exists(archive_dir):
|
924 |
+
untar_all(archive_dir, target_dir, remove_source)
|
925 |
+
else:
|
926 |
+
raise RuntimeError(f"Could not find archive folder: {archive_dir}")
|
927 |
+
|
928 |
+
def get_dataset_by_name(self, name, stage, data_for_plotting, n_frames=None):
|
929 |
+
|
930 |
+
if name == "vggss":
|
931 |
+
resize_op = T.Resize((self.load_size, self.load_size), Image.BILINEAR)
|
932 |
+
else:
|
933 |
+
resize_op = T.Resize(self.load_size, Image.BILINEAR)
|
934 |
+
|
935 |
+
img_transform = T.Compose([
|
936 |
+
resize_op,
|
937 |
+
T.CenterCrop(self.load_size),
|
938 |
+
T.ToTensor(),
|
939 |
+
norm])
|
940 |
+
|
941 |
+
if self.image_aug:
|
942 |
+
train_img_transform = T.Compose([
|
943 |
+
T.RandomResizedCrop(self.load_size),
|
944 |
+
T.RandomHorizontalFlip(),
|
945 |
+
T.ColorJitter(.2, .2, .2, .2),
|
946 |
+
T.RandomGrayscale(),
|
947 |
+
T.ToTensor(),
|
948 |
+
norm])
|
949 |
+
val_img_transform = img_transform
|
950 |
+
else:
|
951 |
+
train_img_transform = img_transform
|
952 |
+
val_img_transform = img_transform
|
953 |
+
|
954 |
+
if self.audio_aug:
|
955 |
+
train_audio_aug = True
|
956 |
+
val_audio_aug = False
|
957 |
+
else:
|
958 |
+
train_audio_aug = False
|
959 |
+
val_audio_aug = False
|
960 |
+
|
961 |
+
if self.audio_model_type == "hubert":
|
962 |
+
            from denseav.featurizers.Hubert import HubertAudioTransform
|
963 |
+
audio_transform = HubertAudioTransform()
|
964 |
+
else:
|
965 |
+
audio_transform = None
|
966 |
+
|
967 |
+
if self.audio_model_type == "passt":
|
968 |
+
sample_rate = 32000
|
969 |
+
else:
|
970 |
+
sample_rate = 16000
|
971 |
+
|
972 |
+
if not self.use_cached_embs:
|
973 |
+
if self.audio_model_type == "hubert":
|
974 |
+
self.extra_args["use_audio"] = True
|
975 |
+
elif self.audio_model_type in {"audiomae", "audiomae-finetuned", "cavmae", "cavmae-mixed", "imagebind"}:
|
976 |
+
self.extra_args["use_spec"] = True
|
977 |
+
elif self.audio_model_type == "davenet":
|
978 |
+
self.extra_args["use_audio"] = True
|
979 |
+
self.extra_args["use_davenet_spec"] = True
|
980 |
+
elif self.audio_model_type == "fnac":
|
981 |
+
self.extra_args["use_audio"] = True
|
982 |
+
self.extra_args["use_fnac_spec"] = True
|
983 |
+
else:
|
984 |
+
raise ValueError(f"Unknown audio model type {self.audio_model_type}")
|
985 |
+
|
986 |
+
if self.audio_model_type == "cavmae" or self.audio_model_type == "cavmae-mixed":
|
987 |
+
self.extra_args["spec_mean"] = -5.081
|
988 |
+
self.extra_args["spec_std"] = 4.4849
|
989 |
+
elif self.audio_model_type == "imagebind":
|
990 |
+
self.extra_args["spec_mean"] = -4.268
|
991 |
+
self.extra_args["spec_std"] = 9.138
|
992 |
+
|
993 |
+
# if self.audio_model_type in {"audiomae", "audiomae-finetune", "cavmae"} \
|
994 |
+
# and "override_target_length" not in self.extra_args:
|
995 |
+
if "override_target_length" not in self.extra_args:
|
996 |
+
self.extra_args["override_target_length"] = 10
|
997 |
+
|
998 |
+
data_args = dict(
|
999 |
+
root=self.pytorch_data_dir,
|
1000 |
+
use_frames=True,
|
1001 |
+
audio_transform=audio_transform,
|
1002 |
+
sample_rate=sample_rate,
|
1003 |
+
audio_level=self.audio_level,
|
1004 |
+
**self.extra_args
|
1005 |
+
)
|
1006 |
+
|
1007 |
+
if n_frames is not None:
|
1008 |
+
data_args["n_frames"] = n_frames
|
1009 |
+
|
1010 |
+
train_args = dict(
|
1011 |
+
frame_transform=train_img_transform,
|
1012 |
+
extra_audio_masking=self.extra_audio_masking,
|
1013 |
+
neg_audio=self.neg_audio,
|
1014 |
+
quad_mixup=self.quad_mixup,
|
1015 |
+
bg_mixup=self.bg_mixup,
|
1016 |
+
patch_mixup=self.patch_mixup,
|
1017 |
+
patch_size=self.patch_size,
|
1018 |
+
audio_aug=train_audio_aug
|
1019 |
+
)
|
1020 |
+
val_args = dict(
|
1021 |
+
frame_transform=val_img_transform,
|
1022 |
+
audio_aug=val_audio_aug
|
1023 |
+
)
|
1024 |
+
|
1025 |
+
if data_for_plotting:
|
1026 |
+
val_args["use_audio"] = True
|
1027 |
+
val_args["use_spec"] = True
|
1028 |
+
|
1029 |
+
if "ade" in name:
|
1030 |
+
label_transform = T.Compose([
|
1031 |
+
T.Resize(self.load_size, Image.NEAREST),
|
1032 |
+
T.CenterCrop(self.load_size),
|
1033 |
+
prep_ade_label
|
1034 |
+
])
|
1035 |
+
else:
|
1036 |
+
label_transform = T.Compose([
|
1037 |
+
T.Resize(self.load_size, Image.NEAREST),
|
1038 |
+
T.CenterCrop(self.load_size)
|
1039 |
+
])
|
1040 |
+
|
1041 |
+
val_args["use_audio"] = True
|
1042 |
+
val_args["label_transform"] = label_transform
|
1043 |
+
|
1044 |
+
if name == "places-audio":
|
1045 |
+
dataset_constructor = PlacesAudio
|
1046 |
+
elif name == "mixed-full":
|
1047 |
+
dataset_constructor = PlacesAndAudioSet
|
1048 |
+
elif name == "audio-set-full":
|
1049 |
+
dataset_constructor = AudioSet
|
1050 |
+
elif name == "audio-set-eval":
|
1051 |
+
dataset_constructor = AudioSetEval
|
1052 |
+
elif name == "ade":
|
1053 |
+
val_args["use_semseg"] = True
|
1054 |
+
dataset_constructor = ADE20K
|
1055 |
+
elif name == "ade-speech-prompted":
|
1056 |
+
val_args["use_semseg"] = True
|
1057 |
+
dataset_constructor = ADE20KSpeechPrompted
|
1058 |
+
elif name == "ade-sound-prompted":
|
1059 |
+
val_args["use_semseg"] = True
|
1060 |
+
dataset_constructor = ADE20KSoundPrompted
|
1061 |
+
else:
|
1062 |
+
raise ValueError(f"Unknown dataset name {name}")
|
1063 |
+
|
1064 |
+
data_args["use_audio_embed"] = self.use_cached_embs
|
1065 |
+
data_args["audio_embed_model"] = self.audio_model_type
|
1066 |
+
|
1067 |
+
if stage == "full":
|
1068 |
+
val_dataset = dataset_constructor(split="val", **{**data_args, **val_args})
|
1069 |
+
train_dataset = dataset_constructor(split="train", **{**data_args, **val_args})
|
1070 |
+
return ConcatDataset([train_dataset, val_dataset])
|
1071 |
+
elif stage == "fit":
|
1072 |
+
return dataset_constructor(split="train", **{**data_args, **train_args})
|
1073 |
+
elif stage == "validate":
|
1074 |
+
return dataset_constructor(split="val", **{**data_args, **val_args})
|
1075 |
+
else:
|
1076 |
+
raise ValueError(f"Unknown stage: {stage}")
|
1077 |
+
|
1078 |
+
def _maybe_subset(self, dataset, length):
|
1079 |
+
if len(dataset) > length and self.dataset_name not in {"ade-sound-prompted", "ade-speech-prompted", "vggss"}:
|
1080 |
+
print("Using a subset of validation data")
|
1081 |
+
return Subset(dataset, generate_subset(len(dataset), length))
|
1082 |
+
else:
|
1083 |
+
print("Not using val subset")
|
1084 |
+
return dataset
|
1085 |
+
|
1086 |
+
def _make_val_datasets(self):
|
1087 |
+
val_sets = []
|
1088 |
+
if self.use_original_val_set:
|
1089 |
+
val_sets.append(self._maybe_subset(self.get_dataset_by_name(
|
1090 |
+
self.dataset_name, "validate", self.data_for_plotting), 1000))
|
1091 |
+
|
1092 |
+
if self.use_extra_val_sets:
|
1093 |
+
val_sets.append(self._maybe_subset(self.get_dataset_by_name(
|
1094 |
+
"places-audio", "validate", self.data_for_plotting), 1000))
|
1095 |
+
val_sets.append(self._maybe_subset(self.get_dataset_by_name(
|
1096 |
+
"audio-set-eval", "validate", False, n_frames=1), 1000))
|
1097 |
+
val_sets.append(self.get_dataset_by_name(
|
1098 |
+
"ade-speech-prompted", "validate", True))
|
1099 |
+
val_sets.append(self.get_dataset_by_name(
|
1100 |
+
"ade-sound-prompted", "validate", self.data_for_plotting))
|
1101 |
+
|
1102 |
+
return val_sets
|
1103 |
+
|
1104 |
+
def setup(self, stage: str):
|
1105 |
+
if stage == "full":
|
1106 |
+
self.full_dataset = self.get_dataset_by_name(self.dataset_name, stage, self.data_for_plotting)
|
1107 |
+
elif stage == "fit":
|
1108 |
+
self.train_dataset = self.get_dataset_by_name(self.dataset_name, stage, self.data_for_plotting)
|
1109 |
+
self.val_datasets = self._make_val_datasets()
|
1110 |
+
elif stage == "validate":
|
1111 |
+
self.val_datasets = self._make_val_datasets()
|
1112 |
+
else:
|
1113 |
+
raise ValueError(f"Unknown stage: {stage}")
|
1114 |
+
|
1115 |
+
def train_dataloader(self):
|
1116 |
+
return DataLoader(self.train_dataset, shuffle=True, **self.loader_args, collate_fn=custom_coallate)
|
1117 |
+
|
1118 |
+
def subsampled_train_dataloader(self, k=5000):
|
1119 |
+
if len(self.train_dataset) > k:
|
1120 |
+
ds = Subset(self.train_dataset, generate_subset(len(self.train_dataset), k))
|
1121 |
+
else:
|
1122 |
+
ds = self.train_dataset
|
1123 |
+
|
1124 |
+
return DataLoader(ds, shuffle=True, **self.loader_args, collate_fn=custom_coallate)
|
1125 |
+
|
1126 |
+
def val_dataloader(self):
|
1127 |
+
return [
|
1128 |
+
DataLoader(dataset, shuffle=False, **self.loader_args, collate_fn=custom_coallate)
|
1129 |
+
for dataset in self.val_datasets
|
1130 |
+
]
|
1131 |
+
|
1132 |
+
def full_dataloader(self):
|
1133 |
+
return DataLoader(self.full_dataset, shuffle=False, **self.loader_args, collate_fn=custom_coallate)
|
1134 |
+
|
1135 |
+
|
1136 |
+
def generate_subset(n, batch, seed=0):
|
1137 |
+
np.random.seed(seed)
|
1138 |
+
return np.random.permutation(n)[:batch]
|
1139 |
+
|
1140 |
+
|
1141 |
+
def prep_ade_label(img):
|
1142 |
+
seg = np.array(img)
|
1143 |
+
class_labels = (seg[:, :, 0] / 10).astype(np.int32) * 256 + (seg[:, :, 1].astype(np.int32))
|
1144 |
+
return class_labels
|
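`prep_ade_label` decodes the ADE20K segmentation PNGs, where the class index is packed into the red and green channels exactly as computed above; a quick sanity check with a made-up pixel:

```python
import numpy as np
from denseav.data.AVDatasets import prep_ade_label

# One fake pixel with R=20, G=5: class index = (20 // 10) * 256 + 5 = 517.
fake_seg = np.zeros((1, 1, 3), dtype=np.uint8)
fake_seg[0, 0] = [20, 5, 0]
print(prep_ade_label(fake_seg))  # [[517]]
```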
1145 |
+
|
1146 |
+
|
1147 |
+
def maybe_replace(e, not_none):
|
1148 |
+
if e is not None:
|
1149 |
+
return e
|
1150 |
+
else:
|
1151 |
+
        print("Warning: found a None in the dataset, indicative of a loading failure; replacing it with another item")
|
1152 |
+
return not_none[0]
|
1153 |
+
|
1154 |
+
|
1155 |
+
empty_caption = {
|
1156 |
+
"words": [],
|
1157 |
+
"start": [],
|
1158 |
+
"end": [],
|
1159 |
+
}
|
1160 |
+
|
1161 |
+
|
1162 |
+
def custom_coallate(l):
|
1163 |
+
if l is None:
|
1164 |
+
return l
|
1165 |
+
|
1166 |
+
not_none = [e for e in l if e is not None]
|
1167 |
+
assert len(not_none) > 0
|
1168 |
+
|
1169 |
+
l = [maybe_replace(e, not_none) for e in l]
|
1170 |
+
|
1171 |
+
to_merge = {}
|
1172 |
+
|
1173 |
+
def pop_or_default(dict, k, default):
|
1174 |
+
if k in dict:
|
1175 |
+
return dict.pop(k)
|
1176 |
+
else:
|
1177 |
+
print(f"WARNING: Could not find {k}, using {default}")
|
1178 |
+
return default
|
1179 |
+
|
1180 |
+
if "caption" in l[0]:
|
1181 |
+
to_merge["caption"] = [pop_or_default(l[i], "caption", empty_caption) for i in range(len(l))]
|
1182 |
+
|
1183 |
+
if "text" in l[0]:
|
1184 |
+
to_merge["text"] = [pop_or_default(l[i], "text", "") for i in range(len(l))]
|
1185 |
+
|
1186 |
+
result = default_collate(l)
|
1187 |
+
|
1188 |
+
return {**result, **to_merge}
|
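`custom_coallate` behaves like `default_collate` except that failed loads (`None`) are swapped for a surviving item and the variable-length caption/text fields are kept as plain Python lists; a small sketch of the happy path:

```python
import torch
from denseav.data.AVDatasets import custom_coallate

item_a = {"audio_length": torch.tensor(1.0), "text": "a dog barks"}
item_b = {"audio_length": torch.tensor(2.0), "text": "a cat meows"}

batch = custom_coallate([item_a, item_b])
print(batch["text"])          # ['a dog barks', 'a cat meows']  (kept as a list)
print(batch["audio_length"])  # tensor([1., 2.])                (via default_collate)
```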
1189 |
+
|
1190 |
+
|
1191 |
+
if __name__ == "__main__":
|
1192 |
+
|
1193 |
+
    from denseav.featurizers.Hubert import HubertAudioTransform
|
1194 |
+
|
1195 |
+
pytorch_data_dir = "/pytorch-data"
|
1196 |
+
dataset_constructor = PlacesAudio
|
1197 |
+
split = "val"
|
1198 |
+
|
1199 |
+
img_transform = T.Compose([
|
1200 |
+
T.Resize(224, Image.BILINEAR),
|
1201 |
+
T.CenterCrop(224),
|
1202 |
+
T.ToTensor(),
|
1203 |
+
norm])
|
1204 |
+
|
1205 |
+
video_transform = T.Compose([
|
1206 |
+
T.Resize(224, Image.BILINEAR),
|
1207 |
+
T.CenterCrop(224),
|
1208 |
+
norm])
|
1209 |
+
|
1210 |
+
label_transform = T.Compose([
|
1211 |
+
T.Resize(224, Image.NEAREST),
|
1212 |
+
T.CenterCrop(224)
|
1213 |
+
])
|
1214 |
+
|
1215 |
+
audio_transform = HubertAudioTransform()
|
1216 |
+
|
1217 |
+
data_args = dict(
|
1218 |
+
root=pytorch_data_dir,
|
1219 |
+
frame_transform=img_transform,
|
1220 |
+
use_frames=True,
|
1221 |
+
use_spec=True,
|
1222 |
+
use_audio=True,
|
1223 |
+
use_caption=False,
|
1224 |
+
use_semseg=False,
|
1225 |
+
label_transform=label_transform,
|
1226 |
+
audio_transform=audio_transform,
|
1227 |
+
use_audio_embed=False,
|
1228 |
+
audio_embed_model="audiomae",
|
1229 |
+
extra_audio_masking=False,
|
1230 |
+
neg_audio=False,
|
1231 |
+
override_target_length=10,
|
1232 |
+
audio_level=False,
|
1233 |
+
quad_mixup=.3,
|
1234 |
+
patch_mixup=.3,
|
1235 |
+
bg_mixup=.3,
|
1236 |
+
)
|
1237 |
+
|
1238 |
+
|
1239 |
+
def return_datasets(dataset_constructor, split):
|
1240 |
+
dataset = dataset_constructor(split=split, **data_args)
|
1241 |
+
return dataset
|
1242 |
+
|
1243 |
+
|
1244 |
+
train_ds = return_datasets(dataset_constructor, split)
|
1245 |
+
|
1246 |
+
print(len(train_ds))
|
1247 |
+
train_loader = DataLoader(train_ds, batch_size=1, shuffle=False, num_workers=36, collate_fn=custom_coallate)
|
1248 |
+
for batch in tqdm(train_loader):
|
1249 |
+
pass
|
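For reference, a rough sketch of how the `AVDataModule` defined in this file is typically driven; the argument values here are placeholders rather than the project's defaults, which normally come from the hydra config:

```python
from denseav.data.AVDatasets import AVDataModule

datamodule = AVDataModule(
    dataset_name="places-audio",
    load_size=224,
    image_aug=False,
    audio_aug=False,
    extra_audio_masking=False,
    audio_model_type="hubert",
    pytorch_data_dir="/pytorch-data",
    use_cached_embs=False,
    batch_size=8,
    num_workers=4,
    audio_level=False,
    neg_audio=False,
    data_for_plotting=False,
    use_original_val_set=True,
    use_extra_val_sets=False,
    quad_mixup=0.0,
    bg_mixup=0.0,
    patch_mixup=0.0,
    patch_size=8,
)
datamodule.setup("fit")                       # builds the train set and validation sets
train_loader = datamodule.train_dataloader()  # batches collated with custom_coallate
```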
DenseAV/denseav/data/__init__.py
ADDED
File without changes
|
DenseAV/denseav/data/make_tarballs.py
ADDED
@@ -0,0 +1,108 @@
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import tarfile
|
4 |
+
from glob import glob
|
5 |
+
from io import BytesIO
|
6 |
+
from os.path import join
|
7 |
+
|
8 |
+
from torch.utils.data import Dataset, DataLoader
|
9 |
+
from tqdm import tqdm
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
from denseav.shared import batch
|
13 |
+
|
14 |
+
import tempfile
|
15 |
+
import shutil
|
16 |
+
|
17 |
+
|
18 |
+
class Tarballer(Dataset):
|
19 |
+
|
20 |
+
def __init__(self, source, target, n):
|
21 |
+
source_path = Path(source)
|
22 |
+
self.frames = [f.relative_to(source_path) for f in source_path.rglob('*') if f.is_file()]
|
23 |
+
assert (len(self.frames) > 0)
|
24 |
+
self.source = source
|
25 |
+
self.target_dir = target
|
26 |
+
self.batched = list(batch(self.frames, n))
|
27 |
+
os.makedirs(self.target_dir, exist_ok=True)
|
28 |
+
|
29 |
+
def __len__(self):
|
30 |
+
return len(self.batched)
|
31 |
+
|
32 |
+
def __getitem__(self, item):
|
33 |
+
with tarfile.open(join(self.target_dir, f"{item}.tar"), "w") as tar:
|
34 |
+
for relpath in self.batched[item]:
|
35 |
+
abs_path = os.path.join(self.source, str(relpath)) # Convert to string here
|
36 |
+
with open(abs_path, "rb") as file:
|
37 |
+
file_content = file.read()
|
38 |
+
info = tarfile.TarInfo(name=str(relpath)) # Convert to string here
|
39 |
+
info.size = len(file_content)
|
40 |
+
tar.addfile(info, fileobj=BytesIO(file_content))
|
41 |
+
|
42 |
+
return 0
|
43 |
+
|
44 |
+
|
45 |
+
class UnTarballer:
|
46 |
+
|
47 |
+
def __init__(self, archive_dir, target_dir, remove_source=False):
|
48 |
+
self.tarballs = sorted(glob(join(archive_dir, "*.tar")))
|
49 |
+
self.target_dir = target_dir
|
50 |
+
self.remove_source = remove_source # New flag to determine if source tarball should be removed
|
51 |
+
os.makedirs(self.target_dir, exist_ok=True)
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
return len(self.tarballs)
|
55 |
+
|
56 |
+
def __getitem__(self, item):
|
57 |
+
with tarfile.open(self.tarballs[item], "r") as tar:
|
58 |
+
# Create a unique temporary directory inside the target directory
|
59 |
+
with tempfile.TemporaryDirectory(dir=self.target_dir) as tmpdirname:
|
60 |
+
tar.extractall(tmpdirname) # Extract to the temporary directory
|
61 |
+
|
62 |
+
# Move contents from temporary directory to final target directory
|
63 |
+
for src_dir, dirs, files in os.walk(tmpdirname):
|
64 |
+
dst_dir = src_dir.replace(tmpdirname, self.target_dir, 1)
|
65 |
+
os.makedirs(dst_dir, exist_ok=True)
|
66 |
+
for file_ in files:
|
67 |
+
src_file = os.path.join(src_dir, file_)
|
68 |
+
dst_file = os.path.join(dst_dir, file_)
|
69 |
+
shutil.move(src_file, dst_file)
|
70 |
+
|
71 |
+
# Remove the source tarball if the flag is set to True
|
72 |
+
if self.remove_source:
|
73 |
+
os.remove(self.tarballs[item])
|
74 |
+
|
75 |
+
return 0
|
76 |
+
|
77 |
+
def untar_all(archive_dir, target_dir, remove_source):
|
78 |
+
loader = DataLoader(UnTarballer(archive_dir, target_dir, remove_source), num_workers=24)
|
79 |
+
for _ in tqdm(loader):
|
80 |
+
pass
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == "__main__":
|
84 |
+
# loader = DataLoader(Tarballer(
|
85 |
+
# join("/pytorch-data", "audioset-raw", "audio"),
|
86 |
+
# join("/pytorch-data", "audioset-raw", "audio_archives")
|
87 |
+
# ), num_workers=24)
|
88 |
+
|
89 |
+
# loader = DataLoader(Tarballer(
|
90 |
+
# join("/pytorch-data", "audioset-raw", "frames"),
|
91 |
+
# join("/pytorch-data", "audioset-raw", "frame_archives"),
|
92 |
+
# 5000
|
93 |
+
# ), num_workers=24)
|
94 |
+
|
95 |
+
# loader = DataLoader(Tarballer(
|
96 |
+
# join("/pytorch-data", "ADE20KLabels"),
|
97 |
+
# join("/pytorch-data", "ADE20KLabelsAr"),
|
98 |
+
# 100
|
99 |
+
# ), num_workers=24)
|
100 |
+
#
|
101 |
+
# for _ in tqdm(loader):
|
102 |
+
# pass
|
103 |
+
#
|
104 |
+
# #
|
105 |
+
#
|
106 |
+
untar_all(
|
107 |
+
join("/pytorch-data", "audioset-raw", "frame_archives"),
|
108 |
+
        join("/pytorch-data", "audioset-raw", "frames_4"), remove_source=False)
|
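The commented-out `__main__` section hints at the intended workflow; a hedged sketch of packing a directory into tarballs and restoring it later (the paths are placeholders):

```python
from torch.utils.data import DataLoader
from tqdm import tqdm

from denseav.data.make_tarballs import Tarballer, untar_all

# Pack every file under the source directory into tarballs of 1000 files each;
# the DataLoader workers simply parallelize the per-tarball writes.
packer = Tarballer("/pytorch-data/my-frames", "/pytorch-data/my-frame-archives", 1000)
for _ in tqdm(DataLoader(packer, num_workers=8)):
    pass

# Restore the original layout (pass remove_source=True to delete each tarball after extraction).
untar_all("/pytorch-data/my-frame-archives", "/pytorch-data/my-frames-restored", remove_source=False)
```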
DenseAV/denseav/eval_utils.py
ADDED
@@ -0,0 +1,135 @@
1 |
+
import json
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from torchmetrics.functional.classification import binary_average_precision
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from denseav.constants import *
|
12 |
+
from denseav.shared import unnorm, remove_axes
|
13 |
+
|
14 |
+
|
15 |
+
def prep_heatmap(sims, masks, h, w):
|
16 |
+
masks = masks.to(torch.float32)
|
17 |
+
hm = torch.einsum("bhwt,bt->bhw", sims, masks) / masks.sum(-1).reshape(-1, 1, 1)
|
18 |
+
hm -= hm.min()
|
19 |
+
hm /= hm.max()
|
20 |
+
return F.interpolate(hm.unsqueeze(1), (h, w), mode="bilinear").squeeze(1)
|
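For orientation, the shapes involved in `prep_heatmap` (sizes illustrative, with `prep_heatmap` as defined above in scope):

```python
import torch

B, h, w, T = 2, 14, 14, 50
sims = torch.rand(B, h, w, T)    # patch-by-audio-token similarity volume
masks = torch.ones(B, T)         # 1 where an audio token should contribute to the heatmap

# The einsum averages sims over the masked time steps into a (B, h, w) map, which is
# then min-max normalized and bilinearly upsampled to the requested resolution.
heatmaps = prep_heatmap(sims, masks, 224, 224)
print(heatmaps.shape)            # torch.Size([2, 224, 224])
```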
21 |
+
|
22 |
+
|
23 |
+
def iou(prediction, target):
|
24 |
+
prediction = prediction > 0.0
|
25 |
+
target = target > 0.5
|
26 |
+
intersection = torch.logical_and(prediction, target).sum().float()
|
27 |
+
union = torch.logical_or(prediction, target).sum().float()
|
28 |
+
if union == 0:
|
29 |
+
return 1.0
|
30 |
+
return (intersection / union).item() # Convert to Python scalar
|
31 |
+
|
32 |
+
|
33 |
+
def multi_iou(prediction, target, k=20):
|
34 |
+
prediction = torch.tensor(prediction)
|
35 |
+
target = torch.tensor(target)
|
36 |
+
target = target > 0.5
|
37 |
+
|
38 |
+
thresholds = torch.linspace(prediction.min(), prediction.max(), k)
|
39 |
+
hard_pred = prediction.unsqueeze(0) > thresholds.reshape(k, 1, 1, 1, 1)
|
40 |
+
target = torch.broadcast_to(target.unsqueeze(0), hard_pred.shape)
|
41 |
+
|
42 |
+
# Calculate IoU for each threshold
|
43 |
+
intersection = torch.logical_and(hard_pred, target).sum(dim=(1, 2, 3, 4)).float()
|
44 |
+
union = torch.logical_or(hard_pred, target).sum(dim=(1, 2, 3, 4)).float()
|
45 |
+
union = torch.where(union == 0, torch.tensor(1.0), union) # Avoid division by zero
|
46 |
+
iou_scores = intersection / union
|
47 |
+
|
48 |
+
# Find the best IoU and corresponding threshold
|
49 |
+
best_iou, best_idx = torch.max(iou_scores, dim=0)
|
50 |
+
# best_threshold = thresholds[best_idx]
|
51 |
+
# print(best_threshold)
|
52 |
+
return best_iou # , best_threshold.item()
|
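A tiny worked example of the two IoU helpers (toy tensors, purely illustrative, with `iou` and `multi_iou` as defined above in scope):

```python
import torch

pred = torch.tensor([0.9, 0.8, -0.5, -0.2])   # soft predictions; iou() thresholds them at 0
target = torch.tensor([1.0, 0.0, 0.0, 1.0])

# Hard IoU at threshold 0: predicted {0, 1} vs true {0, 3} -> intersection 1, union 3.
print(iou(pred, target))       # ~0.333

# multi_iou sweeps 20 thresholds over the prediction range and keeps the best IoU,
# here ~0.667 (a low threshold keeps {0, 1, 3}).
print(multi_iou(pred, target))
```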
53 |
+
|
54 |
+
|
55 |
+
def get_paired_heatmaps(
|
56 |
+
model,
|
57 |
+
results,
|
58 |
+
class_ids,
|
59 |
+
timing,
|
60 |
+
class_names=None):
|
61 |
+
sims = model.sim_agg.get_pairwise_sims(
|
62 |
+
results,
|
63 |
+
raw=False,
|
64 |
+
agg_sim=False,
|
65 |
+
agg_heads=True
|
66 |
+
).squeeze(1).mean(-2)
|
67 |
+
|
68 |
+
prompt_classes = torch.tensor(list(class_ids))
|
69 |
+
gt = results["semseg"] == prompt_classes.reshape(-1, 1, 1)
|
70 |
+
basic_masks = results[AUDIO_MASK] # BxT
|
71 |
+
_, fullh, fullw = gt.shape
|
72 |
+
basic_heatmaps = prep_heatmap(sims, basic_masks, fullh, fullw)
|
73 |
+
|
74 |
+
if timing is not None:
|
75 |
+
prompt_timing = np.array(list(timing))
|
76 |
+
raw_timing = torch.tensor([json.loads(t) for t in prompt_timing])
|
77 |
+
timing = torch.clone(raw_timing)
|
78 |
+
timing[:, 0] -= .2
|
79 |
+
timing[:, 1] += .2
|
80 |
+
total_length = (results['total_length'] / 16000)[0]
|
81 |
+
fracs = timing / total_length
|
82 |
+
bounds = basic_masks.shape[1] * fracs
|
83 |
+
bounds[:, 0] = bounds[:, 0].floor()
|
84 |
+
bounds[:, 1] = bounds[:, 1].ceil()
|
85 |
+
bounds = bounds.to(torch.int64)
|
86 |
+
advanced_masks = (F.one_hot(bounds, basic_masks.shape[1]).cumsum(-1).sum(-2) == 1).to(basic_masks)
|
87 |
+
advanced_heatmaps = prep_heatmap(sims, advanced_masks, fullh, fullw)
|
88 |
+
|
89 |
+
metrics = defaultdict(list)
|
90 |
+
unique_classes = torch.unique(prompt_classes)
|
91 |
+
|
92 |
+
should_plot = class_names is not None
|
93 |
+
|
94 |
+
if should_plot:
|
95 |
+
prompt_names = np.array(list(class_names))
|
96 |
+
|
97 |
+
for prompt_class in tqdm(unique_classes):
|
98 |
+
subset = torch.where(prompt_classes == prompt_class)[0]
|
99 |
+
gt_subset = gt[subset]
|
100 |
+
basic_subset = basic_heatmaps[subset]
|
101 |
+
metrics["basic_ap"].append(binary_average_precision(basic_subset.flatten(), gt_subset.flatten()))
|
102 |
+
metrics["basic_iou"].append(multi_iou(basic_subset.flatten(), gt_subset.flatten()))
|
103 |
+
|
104 |
+
if timing is not None:
|
105 |
+
advanced_subset = advanced_heatmaps[subset]
|
106 |
+
metrics["advanced_ap"].append(binary_average_precision(advanced_subset.flatten(), gt_subset.flatten()))
|
107 |
+
metrics["advanced_iou"].append(multi_iou(advanced_subset.flatten(), gt_subset.flatten()))
|
108 |
+
|
109 |
+
if should_plot:
|
110 |
+
prompt_class_subset = prompt_classes[subset]
|
111 |
+
name_subset = prompt_names[subset]
|
112 |
+
print(prompt_class, name_subset, prompt_class_subset)
|
113 |
+
n_imgs = min(len(subset), 5)
|
114 |
+
if n_imgs > 1:
|
115 |
+
fig, axes = plt.subplots(n_imgs, 5, figsize=(4 * 5, n_imgs * 3))
|
116 |
+
frame_subset = unnorm(results[IMAGE_INPUT][subset].squeeze(1)).permute(0, 2, 3, 1)
|
117 |
+
semseg_subset = results["semseg"][subset]
|
118 |
+
for img_num in range(n_imgs):
|
119 |
+
axes[img_num, 0].imshow(frame_subset[img_num])
|
120 |
+
axes[img_num, 1].imshow(basic_subset[img_num])
|
121 |
+
axes[img_num, 2].imshow(advanced_subset[img_num])
|
122 |
+
axes[img_num, 3].imshow(gt_subset[img_num])
|
123 |
+
axes[img_num, 4].imshow(semseg_subset[img_num], cmap="tab20", interpolation='none')
|
124 |
+
|
125 |
+
axes[0, 0].set_title("Image")
|
126 |
+
class_name = name_subset[0].split(",")[0]
|
127 |
+
axes[0, 1].set_title(f"{class_name} Basic Heatmap")
|
128 |
+
axes[0, 2].set_title(f"{class_name} Advanced Heatmap")
|
129 |
+
axes[0, 3].set_title("True Mask")
|
130 |
+
axes[0, 4].set_title("True Seg")
|
131 |
+
remove_axes(axes)
|
132 |
+
plt.tight_layout()
|
133 |
+
plt.show()
|
134 |
+
|
135 |
+
return metrics, unique_classes
|
DenseAV/denseav/evaluate.py
ADDED
@@ -0,0 +1,87 @@
1 |
+
from os.path import join
|
2 |
+
import hydra
|
3 |
+
from omegaconf import DictConfig, OmegaConf
|
4 |
+
from pytorch_lightning import Trainer
|
5 |
+
from pytorch_lightning import seed_everything
|
6 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
7 |
+
from denseav.data.AVDatasets import AVDataModule
|
8 |
+
from denseav.shared import load_trained_model
|
9 |
+
|
10 |
+
|
11 |
+
@hydra.main(config_path="configs", config_name="av_align.yaml")
|
12 |
+
def my_app(cfg: DictConfig) -> None:
|
13 |
+
    from denseav.saved_models import saved_model_dict
|
14 |
+
|
15 |
+
seed_everything(0)
|
16 |
+
print(OmegaConf.to_yaml(cfg))
|
17 |
+
|
18 |
+
models_to_eval = [
|
19 |
+
"denseav_language",
|
20 |
+
"denseav_sound",
|
21 |
+
]
|
22 |
+
|
23 |
+
checkpoint_dir = "../checkpoints"
|
24 |
+
saved_models = saved_model_dict(checkpoint_dir)
|
25 |
+
for model_name in models_to_eval:
|
26 |
+
model_info = saved_models[model_name]
|
27 |
+
extra_data_args = model_info["data_args"] if "data_args" in model_info else {}
|
28 |
+
model_info["extra_args"]["output_root"] = "../"
|
29 |
+
model_info["extra_args"]["neg_audio"] = False
|
30 |
+
model_info["extra_args"]["image_mixup"] = 0.0
|
31 |
+
|
32 |
+
model = load_trained_model(join(checkpoint_dir, model_info["chkpt_name"]), model_info["extra_args"])
|
33 |
+
model.set_full_train(True)
|
34 |
+
|
35 |
+
if model.image_model_type == "dinov2":
|
36 |
+
load_size = cfg.load_size * 2
|
37 |
+
else:
|
38 |
+
load_size = cfg.load_size
|
39 |
+
|
40 |
+
if model.image_model_type == "davenet":
|
41 |
+
batch_size = cfg.batch_size // 2
|
42 |
+
elif model.image_model_type == "imagebind":
|
43 |
+
batch_size = cfg.batch_size
|
44 |
+
else:
|
45 |
+
batch_size = cfg.batch_size
|
46 |
+
|
47 |
+
print(load_size)
|
48 |
+
|
49 |
+
data_args = dict(
|
50 |
+
dataset_name=cfg.dataset_name,
|
51 |
+
load_size=load_size,
|
52 |
+
image_aug=cfg.image_aug,
|
53 |
+
audio_aug=cfg.audio_aug,
|
54 |
+
audio_model_type=model.audio_model_type,
|
55 |
+
pytorch_data_dir=cfg.pytorch_data_dir,
|
56 |
+
use_cached_embs=model.use_cached_embs,
|
57 |
+
batch_size=batch_size,
|
58 |
+
num_workers=cfg.num_workers,
|
59 |
+
extra_audio_masking=False,
|
60 |
+
use_original_val_set=False,
|
61 |
+
use_extra_val_sets=True,
|
62 |
+
use_caption=True,
|
63 |
+
data_for_plotting=False,
|
64 |
+
n_frames=None,
|
65 |
+
audio_level=False,
|
66 |
+
neg_audio=False,
|
67 |
+
quad_mixup=0.0,
|
68 |
+
bg_mixup=0.0,
|
69 |
+
patch_mixup=0.0,
|
70 |
+
patch_size=8,
|
71 |
+
)
|
72 |
+
data_args = {**data_args, **extra_data_args}
|
73 |
+
|
74 |
+
datamodule = AVDataModule(**data_args)
|
75 |
+
log_dir = join(cfg.output_root, "logs", "evaluate", model_name)
|
76 |
+
print(log_dir)
|
77 |
+
tb_logger = TensorBoardLogger(log_dir, default_hp_metric=False)
|
78 |
+
trainer = Trainer(
|
79 |
+
accelerator='gpu',
|
80 |
+
strategy="ddp",
|
81 |
+
devices=cfg.num_gpus,
|
82 |
+
logger=tb_logger)
|
83 |
+
trainer.validate(model, datamodule)
|
84 |
+
|
85 |
+
|
86 |
+
if __name__ == "__main__":
|
87 |
+
my_app()
|
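For orientation, the loop above can be reduced to a single-model, single-GPU pass. This is only a sketch, not a supported entry point: it assumes the same "../checkpoints" layout, that saved_models resolves to the same module imported inside my_app, and that data_args is built exactly as in the script above:

from os.path import join
from pytorch_lightning import Trainer
from denseav.data.AVDatasets import AVDataModule
from denseav.shared import load_trained_model
from saved_models import saved_model_dict

info = saved_model_dict("../checkpoints")["denseav_sound"]
model = load_trained_model(join("../checkpoints", info["chkpt_name"]), info["extra_args"])
datamodule = AVDataModule(**data_args)  # data_args: the same dict constructed in my_app above
Trainer(accelerator="gpu", devices=1).validate(model, datamodule)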
DenseAV/denseav/featurizers/AudioMAE.py
ADDED
@@ -0,0 +1,570 @@
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import warnings
|
4 |
+
from functools import partial
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchaudio
|
11 |
+
from timm.models.layers import to_2tuple
|
12 |
+
from torch.utils.data import Dataset
|
13 |
+
from torchaudio.functional import resample
|
14 |
+
import pickle
|
15 |
+
|
16 |
+
|
17 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
18 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
19 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
20 |
+
def norm_cdf(x):
|
21 |
+
# Computes standard normal cumulative distribution function
|
22 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
23 |
+
|
24 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
25 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
26 |
+
"The distribution of values may be incorrect.",
|
27 |
+
stacklevel=2)
|
28 |
+
|
29 |
+
with torch.no_grad():
|
30 |
+
# Values are generated by using a truncated uniform distribution and
|
31 |
+
# then using the inverse CDF for the normal distribution.
|
32 |
+
# Get upper and lower cdf values
|
33 |
+
l = norm_cdf((a - mean) / std)
|
34 |
+
u = norm_cdf((b - mean) / std)
|
35 |
+
|
36 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
37 |
+
# [2l-1, 2u-1].
|
38 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
39 |
+
|
40 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
41 |
+
# standard normal
|
42 |
+
tensor.erfinv_()
|
43 |
+
|
44 |
+
# Transform to proper mean, std
|
45 |
+
tensor.mul_(std * math.sqrt(2.))
|
46 |
+
tensor.add_(mean)
|
47 |
+
|
48 |
+
# Clamp to ensure it's in the proper range
|
49 |
+
tensor.clamp_(min=a, max=b)
|
50 |
+
return tensor
|
51 |
+
|
52 |
+
|
53 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
54 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
55 |
+
r"""Fills the input Tensor with values drawn from a truncated
|
56 |
+
normal distribution. The values are effectively drawn from the
|
57 |
+
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
|
58 |
+
with values outside :math:`[a, b]` redrawn until they are within
|
59 |
+
the bounds. The method used for generating the random values works
|
60 |
+
best when :math:`a \leq \text{mean} \leq b`.
|
61 |
+
Args:
|
62 |
+
tensor: an n-dimensional `torch.Tensor`
|
63 |
+
mean: the mean of the normal distribution
|
64 |
+
std: the standard deviation of the normal distribution
|
65 |
+
a: the minimum cutoff value
|
66 |
+
b: the maximum cutoff value
|
67 |
+
Examples:
|
68 |
+
>>> w = torch.empty(3, 5)
|
69 |
+
>>> nn.init.trunc_normal_(w)
|
70 |
+
"""
|
71 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
72 |
+
|
73 |
+
|
74 |
+
class Mlp(nn.Module):
|
75 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
76 |
+
super().__init__()
|
77 |
+
out_features = out_features or in_features
|
78 |
+
hidden_features = hidden_features or in_features
|
79 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
80 |
+
self.act = act_layer()
|
81 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
82 |
+
self.drop = nn.Dropout(drop)
|
83 |
+
|
84 |
+
def forward(self, x):
|
85 |
+
x = self.fc1(x)
|
86 |
+
x = self.act(x)
|
87 |
+
x = self.drop(x)
|
88 |
+
x = self.fc2(x)
|
89 |
+
x = self.drop(x)
|
90 |
+
return x
|
91 |
+
|
92 |
+
|
93 |
+
class Attention(nn.Module):
|
94 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
95 |
+
super().__init__()
|
96 |
+
self.num_heads = num_heads
|
97 |
+
head_dim = dim // num_heads
|
98 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
99 |
+
self.scale = qk_scale or head_dim ** -0.5
|
100 |
+
|
101 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
102 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
103 |
+
self.proj = nn.Linear(dim, dim)
|
104 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
B, N, C = x.shape
|
108 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
109 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
110 |
+
|
111 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
112 |
+
attn = attn.softmax(dim=-1)
|
113 |
+
attn = self.attn_drop(attn)
|
114 |
+
|
115 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
116 |
+
x = self.proj(x)
|
117 |
+
x = self.proj_drop(x)
|
118 |
+
return x
|
119 |
+
|
120 |
+
|
121 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
122 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
123 |
+
|
124 |
+
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
125 |
+
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
126 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
127 |
+
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
128 |
+
'survival rate' as the argument.
|
129 |
+
|
130 |
+
"""
|
131 |
+
if drop_prob == 0. or not training:
|
132 |
+
return x
|
133 |
+
keep_prob = 1 - drop_prob
|
134 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
135 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
136 |
+
random_tensor.floor_() # binarize
|
137 |
+
output = x.div(keep_prob) * random_tensor
|
138 |
+
return output
|
139 |
+
|
140 |
+
|
141 |
+
class DropPath(nn.Module):
|
142 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
143 |
+
"""
|
144 |
+
|
145 |
+
def __init__(self, drop_prob=None):
|
146 |
+
super(DropPath, self).__init__()
|
147 |
+
self.drop_prob = drop_prob
|
148 |
+
|
149 |
+
def forward(self, x):
|
150 |
+
return drop_path(x, self.drop_prob, self.training)
|
151 |
+
|
152 |
+
|
153 |
+
class Block(nn.Module):
|
154 |
+
|
155 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
156 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
157 |
+
super().__init__()
|
158 |
+
self.norm1 = norm_layer(dim)
|
159 |
+
self.attn = Attention(
|
160 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
161 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
162 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
163 |
+
self.norm2 = norm_layer(dim)
|
164 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
165 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
166 |
+
|
167 |
+
def forward(self, x):
|
168 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
169 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
170 |
+
return x
|
171 |
+
|
172 |
+
|
173 |
+
class PatchEmbed(nn.Module):
|
174 |
+
""" Image to Patch Embedding
|
175 |
+
"""
|
176 |
+
|
177 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
178 |
+
super().__init__()
|
179 |
+
img_size = to_2tuple(img_size)
|
180 |
+
patch_size = to_2tuple(patch_size)
|
181 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
182 |
+
self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
|
183 |
+
self.img_size = img_size
|
184 |
+
self.patch_size = patch_size
|
185 |
+
self.num_patches = num_patches
|
186 |
+
|
187 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
188 |
+
|
189 |
+
def forward(self, x):
|
190 |
+
B, C, H, W = x.shape
|
191 |
+
# FIXME look at relaxing size constraints
|
192 |
+
# assert H == self.img_size[0] and W == self.img_size[1], \
|
193 |
+
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
194 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
195 |
+
return x
|
196 |
+
|
197 |
+
|
198 |
+
class HybridEmbed(nn.Module):
|
199 |
+
""" CNN Feature Map Embedding
|
200 |
+
Extract feature map from CNN, flatten, project to embedding dim.
|
201 |
+
"""
|
202 |
+
|
203 |
+
def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
|
204 |
+
super().__init__()
|
205 |
+
assert isinstance(backbone, nn.Module)
|
206 |
+
img_size = to_2tuple(img_size)
|
207 |
+
self.img_size = img_size
|
208 |
+
self.backbone = backbone
|
209 |
+
if feature_size is None:
|
210 |
+
with torch.no_grad():
|
211 |
+
# FIXME this is hacky, but most reliable way of determining the exact dim of the output feature
|
212 |
+
# map for all networks, the feature metadata has reliable channel and stride info, but using
|
213 |
+
# stride to calc feature dim requires info about padding of each stage that isn't captured.
|
214 |
+
training = backbone.training
|
215 |
+
if training:
|
216 |
+
backbone.eval()
|
217 |
+
o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1]
|
218 |
+
feature_size = o.shape[-2:]
|
219 |
+
feature_dim = o.shape[1]
|
220 |
+
backbone.train(training)
|
221 |
+
else:
|
222 |
+
feature_size = to_2tuple(feature_size)
|
223 |
+
feature_dim = self.backbone.feature_info.channels()[-1]
|
224 |
+
self.num_patches = feature_size[0] * feature_size[1]
|
225 |
+
self.proj = nn.Linear(feature_dim, embed_dim)
|
226 |
+
|
227 |
+
def forward(self, x):
|
228 |
+
x = self.backbone(x)[-1]
|
229 |
+
x = x.flatten(2).transpose(1, 2)
|
230 |
+
x = self.proj(x)
|
231 |
+
return x
|
232 |
+
|
233 |
+
|
234 |
+
class TimmVisionTransformer(nn.Module):
|
235 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
236 |
+
"""
|
237 |
+
|
238 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
239 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
240 |
+
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
|
241 |
+
super().__init__()
|
242 |
+
self.num_classes = num_classes
|
243 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
244 |
+
|
245 |
+
if hybrid_backbone is not None:
|
246 |
+
self.patch_embed = HybridEmbed(
|
247 |
+
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
|
248 |
+
else:
|
249 |
+
self.patch_embed = PatchEmbed(
|
250 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
251 |
+
num_patches = self.patch_embed.num_patches
|
252 |
+
|
253 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
254 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
255 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
256 |
+
|
257 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
258 |
+
self.blocks = nn.ModuleList([
|
259 |
+
Block(
|
260 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
261 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
262 |
+
for i in range(depth)])
|
263 |
+
self.norm = norm_layer(embed_dim)
|
264 |
+
|
265 |
+
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
|
266 |
+
# self.repr = nn.Linear(embed_dim, representation_size)
|
267 |
+
# self.repr_act = nn.Tanh()
|
268 |
+
|
269 |
+
# Classifier head
|
270 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
271 |
+
|
272 |
+
trunc_normal_(self.pos_embed, std=.02)
|
273 |
+
trunc_normal_(self.cls_token, std=.02)
|
274 |
+
self.apply(self._init_weights)
|
275 |
+
|
276 |
+
def _init_weights(self, m):
|
277 |
+
if isinstance(m, nn.Linear):
|
278 |
+
trunc_normal_(m.weight, std=.02)
|
279 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
280 |
+
nn.init.constant_(m.bias, 0)
|
281 |
+
elif isinstance(m, nn.LayerNorm):
|
282 |
+
nn.init.constant_(m.bias, 0)
|
283 |
+
nn.init.constant_(m.weight, 1.0)
|
284 |
+
|
285 |
+
@torch.jit.ignore
|
286 |
+
def no_weight_decay(self):
|
287 |
+
return {'pos_embed', 'cls_token'}
|
288 |
+
|
289 |
+
def get_classifier(self):
|
290 |
+
return self.head
|
291 |
+
|
292 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
293 |
+
self.num_classes = num_classes
|
294 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
295 |
+
|
296 |
+
def forward_features(self, x):
|
297 |
+
B = x.shape[0]
|
298 |
+
x = self.patch_embed(x)
|
299 |
+
|
300 |
+
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
301 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
302 |
+
x = x + self.pos_embed
|
303 |
+
x = self.pos_drop(x)
|
304 |
+
|
305 |
+
for blk in self.blocks:
|
306 |
+
x = blk(x)
|
307 |
+
|
308 |
+
x = self.norm(x)
|
309 |
+
return x[:, 0]
|
310 |
+
|
311 |
+
def forward(self, x):
|
312 |
+
x = self.forward_features(x)
|
313 |
+
x = self.head(x)
|
314 |
+
return x
|
315 |
+
|
316 |
+
|
317 |
+
class VisionTransformer(TimmVisionTransformer):
|
318 |
+
""" Vision Transformer with support for global average pooling
|
319 |
+
"""
|
320 |
+
|
321 |
+
def __init__(self, **kwargs):
|
322 |
+
super(VisionTransformer, self).__init__(**kwargs)
|
323 |
+
norm_layer = kwargs['norm_layer']
|
324 |
+
embed_dim = kwargs['embed_dim']
|
325 |
+
self.fc_norm = norm_layer(embed_dim)
|
326 |
+
del self.norm # remove the original norm
|
327 |
+
|
328 |
+
def interpolate_pos_encoding(self, x, embed):
|
329 |
+
new_patches = x.shape[1]
|
330 |
+
old_patches = embed.shape[1]
|
331 |
+
|
332 |
+
w = 8
|
333 |
+
h = int(new_patches / w)
|
334 |
+
if new_patches == old_patches:
|
335 |
+
return embed
|
336 |
+
|
337 |
+
dim = x.shape[-1]
|
338 |
+
pos_embed = nn.functional.interpolate(
|
339 |
+
embed.reshape(1, 64, 8, dim).permute(0, 3, 1, 2),
|
340 |
+
size=(h, w),
|
341 |
+
mode='bicubic',
|
342 |
+
)
|
343 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
344 |
+
return pos_embed
|
345 |
+
|
346 |
+
def forward(self, x):
|
347 |
+
B = x.shape[0]
|
348 |
+
x = self.patch_embed(x)
|
349 |
+
|
350 |
+
x = x + self.interpolate_pos_encoding(x, self.pos_embed[:, 1:, :])
|
351 |
+
|
352 |
+
cls_token = self.cls_token + self.pos_embed[:, :1, :]
|
353 |
+
cls_tokens = cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
354 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
355 |
+
x = self.pos_drop(x)
|
356 |
+
|
357 |
+
for blk in self.blocks:
|
358 |
+
x = blk(x)
|
359 |
+
|
360 |
+
# x = x[:, 1:, :].mean(dim=1) # global pool without cls token
|
361 |
+
# outcome = self.fc_norm(x)
|
362 |
+
|
363 |
+
return x[:, 1:, :].reshape(B, -1, 8, 768).permute(0, 3, 2, 1), x[:, 0]
|
364 |
+
|
365 |
+
|
366 |
+
class NewPatchEmbed(nn.Module):
|
367 |
+
""" Flexible Image to Patch Embedding
|
368 |
+
"""
|
369 |
+
|
370 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10):
|
371 |
+
super().__init__()
|
372 |
+
img_size = to_2tuple(img_size)
|
373 |
+
patch_size = to_2tuple(patch_size)
|
374 |
+
stride = to_2tuple(stride)
|
375 |
+
self.img_size = img_size
|
376 |
+
self.patch_size = patch_size
|
377 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
|
378 |
+
_, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
|
379 |
+
self.patch_hw = (h, w)
|
380 |
+
self.num_patches = h * w
|
381 |
+
|
382 |
+
def get_output_shape(self, img_size):
|
383 |
+
# todo: don't be lazy..
|
384 |
+
return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
|
385 |
+
|
386 |
+
def forward(self, x):
|
387 |
+
x = self.proj(x)
|
388 |
+
x = x.flatten(2).transpose(1, 2)
|
389 |
+
return x
|
390 |
+
|
391 |
+
|
392 |
+
def pca(image_feats_list, dim=3, fit_pca=None):
|
393 |
+
from sklearn.decomposition import PCA
|
394 |
+
|
395 |
+
device = image_feats_list[0].device
|
396 |
+
|
397 |
+
def flatten(tensor, target_size=None):
|
398 |
+
if target_size is not None and fit_pca is None:
|
399 |
+
tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
|
400 |
+
B, C, H, W = tensor.shape
|
401 |
+
return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
|
402 |
+
|
403 |
+
if len(image_feats_list) > 1 and fit_pca is None:
|
404 |
+
target_size = image_feats_list[0].shape[2]
|
405 |
+
else:
|
406 |
+
target_size = None
|
407 |
+
|
408 |
+
flattened_feats = []
|
409 |
+
for feats in image_feats_list:
|
410 |
+
flattened_feats.append(flatten(feats, target_size))
|
411 |
+
x = torch.cat(flattened_feats, dim=0)
|
412 |
+
|
413 |
+
if fit_pca is None:
|
414 |
+
fit_pca = PCA(n_components=dim, svd_solver="arpack").fit(np.nan_to_num(x.detach().numpy()))
|
415 |
+
|
416 |
+
reduced_feats = []
|
417 |
+
for feats in image_feats_list:
|
418 |
+
x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
|
419 |
+
x_red -= x_red.min(dim=0, keepdim=True).values
|
420 |
+
x_red /= x_red.max(dim=0, keepdim=True).values
|
421 |
+
B, C, H, W = feats.shape
|
422 |
+
reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
|
423 |
+
|
424 |
+
return reduced_feats, fit_pca
|
425 |
+
|
426 |
+
|
427 |
+
class AudiosetDataset(Dataset):
|
428 |
+
def __init__(self, audio_conf):
|
429 |
+
self.audio_conf = audio_conf
|
430 |
+
self.melbins = self.audio_conf.get('num_mel_bins')
|
431 |
+
self.dataset = self.audio_conf.get('dataset')
|
432 |
+
self.norm_mean = self.audio_conf.get('mean')
|
433 |
+
self.norm_std = self.audio_conf.get('std')
|
434 |
+
|
435 |
+
print('Dataset: {}, mean {:.3f} and std {:.3f}'.format(self.dataset, self.norm_mean, self.norm_std))
|
436 |
+
print(f'size of dataset {self.__len__()}')
|
437 |
+
|
438 |
+
def _wav2fbank(self, filename):
|
439 |
+
sample_rate = 16000
|
440 |
+
target_length = 10
|
441 |
+
waveform, obs_sr = torchaudio.load(filename)
|
442 |
+
waveform = waveform[0]
|
443 |
+
if obs_sr != sample_rate:
|
444 |
+
waveform = resample(waveform, obs_sr, sample_rate)
|
445 |
+
|
446 |
+
original_length = waveform.shape[0]
|
447 |
+
padding = target_length * sample_rate - original_length
|
448 |
+
|
449 |
+
if padding > 0:
|
450 |
+
m = torch.nn.ZeroPad2d((0, padding))
|
451 |
+
waveform = m(waveform)
|
452 |
+
else:
|
453 |
+
waveform = waveform[:target_length * sample_rate]
|
454 |
+
|
455 |
+
|
456 |
+
waveform = waveform - waveform.mean()
|
457 |
+
|
458 |
+
# 498 128, 998, 128
|
459 |
+
fbank = torchaudio.compliance.kaldi.fbank(
|
460 |
+
waveform.unsqueeze(0),
|
461 |
+
htk_compat=True,
|
462 |
+
sample_frequency=sample_rate,
|
463 |
+
use_energy=False,
|
464 |
+
window_type='hanning',
|
465 |
+
num_mel_bins=128,
|
466 |
+
dither=0.0,
|
467 |
+
frame_shift=10)
|
468 |
+
|
469 |
+
normed_fbank = (fbank - self.norm_mean) / (self.norm_std * 2)
|
470 |
+
|
471 |
+
return normed_fbank
|
472 |
+
|
473 |
+
def __getitem__(self, index):
|
474 |
+
datum = {"wav": "../../samples/example.wav"}
|
475 |
+
fbank = self._wav2fbank(datum['wav'])
|
476 |
+
fbank = fbank.transpose(0, 1).unsqueeze(0) # 1, 128, 1024 (...,freq,time)
|
477 |
+
fbank = torch.transpose(fbank.squeeze(), 0, 1) # time, freq
|
478 |
+
# the output fbank shape is [time_frame_num, frequency_bins], e.g., [1024, 128]
|
479 |
+
return fbank.unsqueeze(0)
|
480 |
+
|
481 |
+
def __len__(self):
|
482 |
+
return 1
|
483 |
+
|
484 |
+
|
485 |
+
class AudioMAE(nn.Module):
|
486 |
+
|
487 |
+
def __init__(self, output_path, finetuned):
|
488 |
+
super().__init__()
|
489 |
+
# build model
|
490 |
+
model = VisionTransformer(
|
491 |
+
patch_size=16,
|
492 |
+
embed_dim=768,
|
493 |
+
depth=12,
|
494 |
+
num_heads=12,
|
495 |
+
mlp_ratio=4,
|
496 |
+
qkv_bias=True,
|
497 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6),
|
498 |
+
num_classes=527,
|
499 |
+
drop_path_rate=0.1)
|
500 |
+
|
501 |
+
img_size = (1024, 128) # 1024, 128
|
502 |
+
emb_dim = 768
|
503 |
+
model.patch_embed = NewPatchEmbed(
|
504 |
+
img_size=img_size, patch_size=(16, 16), in_chans=1, embed_dim=emb_dim, stride=16)
|
505 |
+
num_patches = model.patch_embed.num_patches
|
506 |
+
model.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, emb_dim), requires_grad=False)
|
507 |
+
|
508 |
+
if finetuned:
|
509 |
+
fn = "audiomae_finetuned.pth"
|
510 |
+
else:
|
511 |
+
fn = "audiomae.pth"
|
512 |
+
|
513 |
+
checkpoint = torch.load(os.path.join(output_path, 'models', fn), map_location='cpu')
|
514 |
+
|
515 |
+
checkpoint_model = checkpoint['model']
|
516 |
+
state_dict = model.state_dict()
|
517 |
+
for k in ['head.weight', 'head.bias']:
|
518 |
+
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
|
519 |
+
print(f"Removing key {k} from pretrained checkpoint")
|
520 |
+
del checkpoint_model[k]
|
521 |
+
msg = model.load_state_dict(checkpoint_model, strict=False)
|
522 |
+
print(msg)
|
523 |
+
|
524 |
+
model = model.eval()
|
525 |
+
self.model = model
|
526 |
+
self.config = dict(output_path=output_path, finetuned=finetuned)
|
527 |
+
|
528 |
+
def forward(self, audio, include_cls):
|
529 |
+
patch_tokens, cls_token = self.model(audio)
|
530 |
+
|
531 |
+
if include_cls:
|
532 |
+
return patch_tokens, cls_token
|
533 |
+
else:
|
534 |
+
return patch_tokens
|
535 |
+
|
536 |
+
|
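As a rough shape reference before the demo below (an illustrative sketch, not part of the file): audio_mae stands for an AudioMAE instance constructed as in the __main__ block, and the stated output layout follows from the reshape/permute at the end of VisionTransformer.forward:

fbank = torch.randn(1, 1, 1024, 128)               # (B, 1, time_frames, mel_bins), e.g. a 10 s clip
patches = audio_mae(fbank, include_cls=False)      # -> (1, 768, 8, 64): (B, dim, 128/16 freq patches, 1024/16 time patches)
patches, cls = audio_mae(fbank, include_cls=True)  # the CLS token is returned separately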
537 |
+
if __name__ == '__main__':
|
538 |
+
import os
|
539 |
+
|
540 |
+
device = torch.device("cuda:2")
|
541 |
+
|
542 |
+
torch.manual_seed(0)
|
543 |
+
np.random.seed(0)
|
544 |
+
|
545 |
+
model = AudioMAE("../../", True).to(device)
|
546 |
+
|
547 |
+
audio_conf_val = {
|
548 |
+
'num_mel_bins': 128,
|
549 |
+
'target_length': 1024,
|
550 |
+
'dataset': "audioset",
|
551 |
+
'mode': 'val',
|
552 |
+
'mean': -4.2677393,
|
553 |
+
'std': 4.5689974,
|
554 |
+
}
|
555 |
+
|
556 |
+
dataset = AudiosetDataset(audio_conf=audio_conf_val)
|
557 |
+
|
558 |
+
batch = dataset[0].unsqueeze(0).to(device)
|
559 |
+
|
560 |
+
embeddings = model(batch, include_cls=False)
|
561 |
+
|
562 |
+
import matplotlib.pyplot as plt
|
563 |
+
|
564 |
+
with torch.no_grad():
|
565 |
+
[pca_feats], _ = pca([embeddings])
|
566 |
+
plt.imshow(pca_feats.cpu().squeeze(0).permute(1, 2, 0))
|
567 |
+
plt.show()
|
568 |
+
print("here")
|
569 |
+
|
570 |
+
print("here")
|
DenseAV/denseav/featurizers/CAVMAE.py
ADDED
@@ -0,0 +1,1082 @@
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import timm
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import torchaudio
|
9 |
+
import torchvision.transforms as T
|
10 |
+
from PIL import Image
|
11 |
+
from timm.models.layers import to_2tuple, DropPath
|
12 |
+
from timm.models.vision_transformer import Mlp, PatchEmbed, Block
|
13 |
+
import os
|
14 |
+
|
15 |
+
|
16 |
+
class Attention(nn.Module):
|
17 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
18 |
+
super().__init__()
|
19 |
+
self.num_heads = num_heads
|
20 |
+
head_dim = dim // num_heads
|
21 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
22 |
+
self.scale = qk_scale or head_dim ** -0.5
|
23 |
+
|
24 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
25 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
26 |
+
self.proj = nn.Linear(dim, dim)
|
27 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
28 |
+
|
29 |
+
def forward(self, x):
|
30 |
+
B, N, C = x.shape
|
31 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
32 |
+
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
33 |
+
|
34 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
35 |
+
attn = attn.softmax(dim=-1)
|
36 |
+
attn = self.attn_drop(attn)
|
37 |
+
|
38 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
39 |
+
x = self.proj(x)
|
40 |
+
x = self.proj_drop(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
def get_2d_sincos_pos_embed(embed_dim, grid_h_size, grid_w_size, cls_token=False):
|
45 |
+
"""
|
46 |
+
grid_size: int of the grid height and width
|
47 |
+
return:
|
48 |
+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
49 |
+
"""
|
50 |
+
grid_h = np.arange(grid_h_size, dtype=float)
|
51 |
+
grid_w = np.arange(grid_w_size, dtype=float)
|
52 |
+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
53 |
+
grid = np.stack(grid, axis=0)
|
54 |
+
|
55 |
+
grid = grid.reshape([2, 1, grid_w_size, grid_h_size])
|
56 |
+
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
57 |
+
if cls_token:
|
58 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
59 |
+
return pos_embed
|
60 |
+
|
61 |
+
|
62 |
+
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
63 |
+
assert embed_dim % 2 == 0
|
64 |
+
|
65 |
+
# use half of dimensions to encode grid_h
|
66 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
67 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
68 |
+
|
69 |
+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
70 |
+
return emb
|
71 |
+
|
72 |
+
|
73 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
74 |
+
"""
|
75 |
+
embed_dim: output dimension for each position
|
76 |
+
pos: a list of positions to be encoded: size (M,)
|
77 |
+
out: (M, D)
|
78 |
+
"""
|
79 |
+
assert embed_dim % 2 == 0
|
80 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
81 |
+
omega /= embed_dim / 2.
|
82 |
+
omega = 1. / 10000 ** omega # (D/2,)
|
83 |
+
|
84 |
+
pos = pos.reshape(-1) # (M,)
|
85 |
+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
86 |
+
|
87 |
+
emb_sin = np.sin(out) # (M, D/2)
|
88 |
+
emb_cos = np.cos(out) # (M, D/2)
|
89 |
+
|
90 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
91 |
+
return emb
|
92 |
+
|
93 |
+
|
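A quick shape check for the helpers above (illustrative only): the audio branch below initializes its positional table with an 8 x 64 patch grid (512 patches for a 1024-frame spectrogram) and embed_dim=768:

pe = get_2d_sincos_pos_embed(768, 8, 64, cls_token=False)
print(pe.shape)  # (512, 768): one fixed sin-cos vector per patch, no learned parameters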
94 |
+
# --------------------------------------------------------
|
95 |
+
# Interpolate position embeddings for high-resolution
|
96 |
+
# References:
|
97 |
+
# DeiT: https://github.com/facebookresearch/deit
|
98 |
+
# --------------------------------------------------------
|
99 |
+
def interpolate_pos_embed(model, checkpoint_model):
|
100 |
+
if 'pos_embed' in checkpoint_model:
|
101 |
+
pos_embed_checkpoint = checkpoint_model['pos_embed']
|
102 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
103 |
+
num_patches = model.patch_embed.num_patches
|
104 |
+
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
|
105 |
+
# height (== width) for the checkpoint position embedding
|
106 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
107 |
+
# height (== width) for the new position embedding
|
108 |
+
new_size = int(num_patches ** 0.5)
|
109 |
+
# class_token and dist_token are kept unchanged
|
110 |
+
if orig_size != new_size:
|
111 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
112 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
113 |
+
# only the position tokens are interpolated
|
114 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
115 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
116 |
+
pos_tokens = torch.nn.functional.interpolate(
|
117 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
118 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
119 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
120 |
+
checkpoint_model['pos_embed'] = new_pos_embed
|
121 |
+
|
122 |
+
|
123 |
+
class PatchEmbed(nn.Module):
|
124 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
125 |
+
super().__init__()
|
126 |
+
|
127 |
+
img_size = to_2tuple(img_size)
|
128 |
+
patch_size = to_2tuple(patch_size)
|
129 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
130 |
+
self.img_size = img_size
|
131 |
+
self.patch_size = patch_size
|
132 |
+
self.num_patches = num_patches
|
133 |
+
|
134 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
135 |
+
|
136 |
+
def forward(self, x):
|
137 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class Block(nn.Module):
|
142 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
143 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
144 |
+
super().__init__()
|
145 |
+
self.norm1 = norm_layer(dim)
|
146 |
+
self.norm1_a = norm_layer(dim)
|
147 |
+
self.norm1_v = norm_layer(dim)
|
148 |
+
self.attn = Attention(
|
149 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
150 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
151 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
152 |
+
self.norm2 = norm_layer(dim)
|
153 |
+
self.norm2_a = norm_layer(dim)
|
154 |
+
self.norm2_v = norm_layer(dim)
|
155 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
156 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
157 |
+
|
158 |
+
def forward(self, x, modality=None):
|
159 |
+
if modality is None:
|
160 |
+
x = x + self.drop_path(self.attn(self.norm1(x)))
|
161 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
162 |
+
elif modality == 'a':
|
163 |
+
x = x + self.drop_path(self.attn(self.norm1_a(x)))
|
164 |
+
x = x + self.drop_path(self.mlp(self.norm2_a(x)))
|
165 |
+
elif modality == 'v':
|
166 |
+
x = x + self.drop_path(self.attn(self.norm1_v(x)))
|
167 |
+
x = x + self.drop_path(self.mlp(self.norm2_v(x)))
|
168 |
+
return x
|
169 |
+
|
170 |
+
|
171 |
+
# our main proposed model, for pretraining only, for finetuning, use CAVMAEFT class
|
172 |
+
class CAVMAE(nn.Module):
|
173 |
+
""" CAV-MAE Model
|
174 |
+
"""
|
175 |
+
|
176 |
+
def __init__(self, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
|
177 |
+
embed_dim=768, modality_specific_depth=11, num_heads=12,
|
178 |
+
decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
|
179 |
+
mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False, tr_pos=False):
|
180 |
+
super().__init__()
|
181 |
+
print('A CAV-MAE Model')
|
182 |
+
print('Use norm_pix_loss: ', norm_pix_loss)
|
183 |
+
print('Learnable Positional Embedding: ', tr_pos)
|
184 |
+
|
185 |
+
# the encoder part
|
186 |
+
# override the timm package
|
187 |
+
timm.models.vision_transformer.PatchEmbed = PatchEmbed
|
188 |
+
timm.models.vision_transformer.Block = Block
|
189 |
+
|
190 |
+
self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
|
191 |
+
self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
|
192 |
+
|
193 |
+
self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
|
194 |
+
print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(self.patch_embed_a.num_patches,
|
195 |
+
self.patch_embed_v.num_patches))
|
196 |
+
|
197 |
+
self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
198 |
+
self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
199 |
+
|
200 |
+
self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
|
201 |
+
requires_grad=tr_pos) # fixed sin-cos embedding
|
202 |
+
self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
|
203 |
+
requires_grad=tr_pos) # fixed sin-cos embedding
|
204 |
+
|
205 |
+
# audio-branch
|
206 |
+
self.blocks_a = nn.ModuleList(
|
207 |
+
[Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
|
208 |
+
range(modality_specific_depth)])
|
209 |
+
# visual-branch
|
210 |
+
self.blocks_v = nn.ModuleList(
|
211 |
+
[Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
|
212 |
+
range(modality_specific_depth)])
|
213 |
+
# unified branch
|
214 |
+
self.blocks_u = nn.ModuleList(
|
215 |
+
[Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
|
216 |
+
range(12 - modality_specific_depth)])
|
217 |
+
|
218 |
+
# independent normalization layer for audio, visual, and audio-visual
|
219 |
+
self.norm_a, self.norm_v, self.norm = norm_layer(embed_dim), norm_layer(embed_dim), norm_layer(embed_dim)
|
220 |
+
|
221 |
+
# the decoder part
|
222 |
+
# Project to lower dimension for the decoder
|
223 |
+
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
|
224 |
+
|
225 |
+
# token used for masking
|
226 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
227 |
+
|
228 |
+
self.decoder_modality_a = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
229 |
+
self.decoder_modality_v = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
230 |
+
|
231 |
+
self.decoder_pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, decoder_embed_dim),
|
232 |
+
requires_grad=tr_pos) # fixed sin-cos embedding
|
233 |
+
self.decoder_pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, decoder_embed_dim),
|
234 |
+
requires_grad=tr_pos) # fixed sin-cos embedding
|
235 |
+
|
236 |
+
self.decoder_blocks = nn.ModuleList(
|
237 |
+
[Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer)
|
238 |
+
for i in range(decoder_depth)])
|
239 |
+
|
240 |
+
self.decoder_norm = norm_layer(decoder_embed_dim)
|
241 |
+
|
242 |
+
# project channel is different for two modality, use two projection head
|
243 |
+
self.decoder_pred_a = nn.Linear(decoder_embed_dim, patch_size ** 2 * 1, bias=True) # decoder to patch
|
244 |
+
self.decoder_pred_v = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True) # decoder to patch
|
245 |
+
|
246 |
+
self.norm_pix_loss = norm_pix_loss
|
247 |
+
|
248 |
+
self.initialize_weights()
|
249 |
+
|
250 |
+
print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
|
251 |
+
print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)
|
252 |
+
|
253 |
+
def initialize_weights(self):
|
254 |
+
# initialize (and freeze) pos_embed by sin-cos embedding, opt the cls token, add by myself
|
255 |
+
pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8, int(self.patch_embed_a.num_patches / 8),
|
256 |
+
cls_token=False)
|
257 |
+
self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))
|
258 |
+
|
259 |
+
pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], int(self.patch_embed_v.num_patches ** .5),
|
260 |
+
int(self.patch_embed_v.num_patches ** .5), cls_token=False)
|
261 |
+
self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))
|
262 |
+
|
263 |
+
decoder_pos_embed_a = get_2d_sincos_pos_embed(self.decoder_pos_embed_a.shape[-1], 8,
|
264 |
+
int(self.patch_embed_a.num_patches / 8), cls_token=False)
|
265 |
+
self.decoder_pos_embed_a.data.copy_(torch.from_numpy(decoder_pos_embed_a).float().unsqueeze(0))
|
266 |
+
|
267 |
+
decoder_pos_embed_v = get_2d_sincos_pos_embed(self.decoder_pos_embed_v.shape[-1],
|
268 |
+
int(self.patch_embed_v.num_patches ** .5),
|
269 |
+
int(self.patch_embed_v.num_patches ** .5), cls_token=False)
|
270 |
+
self.decoder_pos_embed_v.data.copy_(torch.from_numpy(decoder_pos_embed_v).float().unsqueeze(0))
|
271 |
+
|
272 |
+
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
|
273 |
+
w = self.patch_embed_a.proj.weight.data
|
274 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
275 |
+
w = self.patch_embed_v.proj.weight.data
|
276 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
277 |
+
|
278 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
279 |
+
torch.nn.init.normal_(self.modality_a, std=.02)
|
280 |
+
torch.nn.init.normal_(self.modality_v, std=.02)
|
281 |
+
torch.nn.init.normal_(self.decoder_modality_a, std=.02)
|
282 |
+
torch.nn.init.normal_(self.decoder_modality_v, std=.02)
|
283 |
+
torch.nn.init.normal_(self.mask_token, std=.02)
|
284 |
+
|
285 |
+
# initialize nn.Linear and nn.LayerNorm
|
286 |
+
self.apply(self._init_weights)
|
287 |
+
|
288 |
+
def _init_weights(self, m):
|
289 |
+
if isinstance(m, nn.Linear):
|
290 |
+
# we use xavier_uniform following official JAX ViT:
|
291 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
292 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
293 |
+
nn.init.constant_(m.bias, 0)
|
294 |
+
elif isinstance(m, nn.LayerNorm):
|
295 |
+
nn.init.constant_(m.bias, 0)
|
296 |
+
nn.init.constant_(m.weight, 1.0)
|
297 |
+
|
298 |
+
def patchify(self, imgs, c, h, w, p=16):
|
299 |
+
"""
|
300 |
+
imgs: (N, 3, H, W)
|
301 |
+
x: (N, L, patch_size**2 *3)
|
302 |
+
"""
|
303 |
+
x = imgs.reshape(shape=(imgs.shape[0], c, h, p, w, p))
|
304 |
+
x = torch.einsum('nchpwq->nhwpqc', x)
|
305 |
+
x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * c))
|
306 |
+
return x
|
307 |
+
|
308 |
+
def unpatchify(self, x, c, h, w, p=16):
|
309 |
+
"""
|
310 |
+
x: (N, L, patch_size**2 *3)
|
311 |
+
imgs: (N, 3, H, W)
|
312 |
+
"""
|
313 |
+
assert h * w == x.shape[1]
|
314 |
+
|
315 |
+
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
316 |
+
x = torch.einsum('nhwpqc->nchpwq', x)
|
317 |
+
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
318 |
+
return imgs
|
319 |
+
|
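patchify and unpatchify are exact inverses; a small round-trip check (illustrative only; neither method reads self, so they are probed unbound on random data):

imgs = torch.randn(2, 3, 224, 224)
tokens = CAVMAE.patchify(None, imgs, c=3, h=14, w=14, p=16)     # (2, 196, 768)
recon = CAVMAE.unpatchify(None, tokens, c=3, h=14, w=14, p=16)  # (2, 3, 224, 224)
assert torch.allclose(recon, imgs)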
320 |
+
def random_masking_unstructured(self, x, mask_ratio):
|
321 |
+
"""
|
322 |
+
Perform per-sample random masking by per-sample shuffling.
|
323 |
+
Per-sample shuffling is done by argsort random noise.
|
324 |
+
x: [N, L, D], sequence
|
325 |
+
"""
|
326 |
+
N, L, D = x.shape # batch, length, dim
|
327 |
+
len_keep = int(L * (1 - mask_ratio))
|
328 |
+
|
329 |
+
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
330 |
+
|
331 |
+
# sort noise for each sample
|
332 |
+
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
|
333 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
334 |
+
|
335 |
+
# keep the first subset
|
336 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
337 |
+
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
338 |
+
|
339 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
340 |
+
mask = torch.ones([N, L], device=x.device)
|
341 |
+
mask[:, :len_keep] = 0
|
342 |
+
# unshuffle to get the binary mask
|
343 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
344 |
+
|
345 |
+
return x_masked, mask, ids_restore
|
346 |
+
|
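A shape sketch for the masking helper above (illustrative only; the method reads nothing from self, so it is probed unbound): with 512 patches and mask_ratio=0.75, 128 patches are kept and 384 are marked for removal:

x = torch.randn(2, 512, 768)                        # (batch, patches, dim)
x_keep, mask, ids_restore = CAVMAE.random_masking_unstructured(None, x, mask_ratio=0.75)
print(x_keep.shape, mask.shape, ids_restore.shape)  # (2, 128, 768) (2, 512) (2, 512)
assert mask.sum(dim=1).eq(384).all()                # 1 marks a removed patch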
347 |
+
def random_masking_structured(self, x, mask_ratio, t=64, f=8, mode='time'):
|
348 |
+
"""
|
349 |
+
Perform per-sample random masking by per-sample shuffling.
|
350 |
+
Per-sample shuffling is done by argsort random noise.
|
351 |
+
x: [N, L, D], sequence
|
352 |
+
"""
|
353 |
+
N, L, D = x.shape # batch, length, dim
|
354 |
+
len_keep = int(L * (1 - mask_ratio))
|
355 |
+
|
356 |
+
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
|
357 |
+
assert L == f * t
|
358 |
+
noise = noise.reshape(N, f, t) # the audio patch is in shape [f,t], not [t,f]
|
359 |
+
if mode == 'time':
|
360 |
+
for i in range(N):
|
361 |
+
mask_t_list = random.sample(range(t), int(t * mask_ratio))
|
362 |
+
for k in mask_t_list:
|
363 |
+
noise[i, :, k] = 1.1 # large value will be removed
|
364 |
+
elif mode == 'freq':
|
365 |
+
for i in range(N):
|
366 |
+
mask_f_list = random.sample(range(f), int(f * mask_ratio))
|
367 |
+
for k in mask_f_list:
|
368 |
+
noise[i, k, :] = 1.1 # large value will be removed
|
369 |
+
elif mode == 'tf':
|
370 |
+
for i in range(N):
|
371 |
+
mask_t_list = random.sample(range(t), int(t * mask_ratio * 0.7))
|
372 |
+
for k in mask_t_list:
|
373 |
+
noise[i, :, k] = 1.1 # large value will be removed
|
374 |
+
for i in range(N):
|
375 |
+
mask_f_list = random.sample(range(f), int(f * mask_ratio * 0.7))
|
376 |
+
for k in mask_f_list:
|
377 |
+
noise[i, k, :] = 1.1 # large value will be removed
|
378 |
+
noise = noise.reshape(N, L)
|
379 |
+
|
380 |
+
# sort noise for each sample, only need to manipulate these two: ids_shuffle, ids_restore
|
381 |
+
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
|
382 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1)
|
383 |
+
|
384 |
+
# keep the first subset
|
385 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
386 |
+
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
|
387 |
+
|
388 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
389 |
+
mask = torch.ones([N, L], device=x.device)
|
390 |
+
mask[:, :len_keep] = 0
|
391 |
+
# unshuffle to get the binary mask
|
392 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
393 |
+
|
394 |
+
return x_masked, mask, ids_restore
|
395 |
+
|
396 |
+
def forward_encoder(self, a, v, mask_ratio_a, mask_ratio_v, mask_mode='unstructured'):
|
397 |
+
# embed patches
|
398 |
+
a = a.unsqueeze(1)
|
399 |
+
a = a.transpose(2, 3)
|
400 |
+
a = self.patch_embed_a(a)
|
401 |
+
a = a + self.pos_embed_a
|
402 |
+
a = a + self.modality_a
|
403 |
+
|
404 |
+
v = self.patch_embed_v(v)
|
405 |
+
v = v + self.pos_embed_v
|
406 |
+
v = v + self.modality_v
|
407 |
+
|
408 |
+
# by default, we always use unstructured masking
|
409 |
+
if mask_mode == 'unstructured':
|
410 |
+
a, mask_a, ids_restore_a = self.random_masking_unstructured(a, mask_ratio_a)
|
411 |
+
# in ablation study, we tried time/freq/tf masking. mode in ['freq', 'time', 'tf']
|
412 |
+
else:
|
413 |
+
a, mask_a, ids_restore_a = self.random_masking_structured(a, mask_ratio_a, t=64, f=8, mode=mask_mode)
|
414 |
+
|
415 |
+
# visual branch always use unstructured masking
|
416 |
+
v, mask_v, ids_restore_v = self.random_masking_unstructured(v, mask_ratio_v)
|
417 |
+
|
418 |
+
# audio and visual stream, independent blocks
|
419 |
+
for blk in self.blocks_a:
|
420 |
+
a = blk(a)
|
421 |
+
|
422 |
+
for blk in self.blocks_v:
|
423 |
+
v = blk(v)
|
424 |
+
|
425 |
+
x = torch.cat((a, v), dim=1)
|
426 |
+
|
427 |
+
# unified stream, shared blocks_u, but independent normalization layers
|
428 |
+
for blk in self.blocks_u:
|
429 |
+
x = blk(x)
|
430 |
+
x = self.norm(x)
|
431 |
+
|
432 |
+
for blk in self.blocks_u:
|
433 |
+
ca = blk(a, 'a')
|
434 |
+
ca = self.norm_a(ca)
|
435 |
+
|
436 |
+
for blk in self.blocks_u:
|
437 |
+
cv = blk(v, 'v')
|
438 |
+
cv = self.norm_v(cv)
|
439 |
+
|
440 |
+
return x, mask_a, ids_restore_a, mask_v, ids_restore_v, ca, cv
|
441 |
+
|
442 |
+
def forward_decoder(self, x, mask_a, ids_restore_a, mask_v, ids_restore_v):
|
443 |
+
|
444 |
+
x = self.decoder_embed(x)
|
445 |
+
|
446 |
+
# append mask tokens to sequence
|
447 |
+
# mask_tokens_a in shape [B, #a_mask_token, mask_token_dim], get the number of masked samples from mask_a[0], which is the first example of the batch, all samples should have same number of masked tokens
|
448 |
+
mask_tokens_a = self.mask_token.repeat(x.shape[0], int(mask_a[0].sum()), 1)
|
449 |
+
a_ = torch.cat([x[:, :self.patch_embed_a.num_patches - int(mask_a[0].sum()), :], mask_tokens_a],
|
450 |
+
dim=1) # no cls token
|
451 |
+
a_ = torch.gather(a_, dim=1, index=ids_restore_a.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
|
452 |
+
|
453 |
+
# similar for the visual modality
|
454 |
+
mask_tokens_v = self.mask_token.repeat(x.shape[0], int(mask_v[0].sum()), 1)
|
455 |
+
v_ = torch.cat([x[:, self.patch_embed_a.num_patches - int(mask_a[0].sum()):, :], mask_tokens_v],
|
456 |
+
dim=1) # no cls token
|
457 |
+
v_ = torch.gather(v_, dim=1, index=ids_restore_v.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
|
458 |
+
|
459 |
+
# concatenate audio and visual tokens
|
460 |
+
x = torch.cat([a_, v_], dim=1)
|
461 |
+
|
462 |
+
decoder_pos_embed = torch.cat([self.decoder_pos_embed_a, self.decoder_pos_embed_v], dim=1)
|
463 |
+
x = x + decoder_pos_embed
|
464 |
+
|
465 |
+
# add modality indication tokens
|
466 |
+
x[:, 0:self.patch_embed_a.num_patches, :] = x[:, 0:self.patch_embed_a.num_patches, :] + self.decoder_modality_a
|
467 |
+
x[:, self.patch_embed_a.num_patches:, :] = x[:, self.patch_embed_a.num_patches:, :] + self.decoder_modality_v
|
468 |
+
|
469 |
+
# apply Transformer blocks
|
470 |
+
for blk in self.decoder_blocks:
|
471 |
+
x = blk(x)
|
472 |
+
x = self.decoder_norm(x)
|
473 |
+
|
474 |
+
# predictor projection
|
475 |
+
x_a = self.decoder_pred_a(x[:, :self.patch_embed_a.num_patches, :])
|
476 |
+
x_v = self.decoder_pred_v(x[:, self.patch_embed_a.num_patches:, :])
|
477 |
+
|
478 |
+
# return audio and video tokens
|
479 |
+
return x_a, x_v
|
480 |
+
|
481 |
+
def forward_contrastive(self, audio_rep, video_rep, bidirect_contrast=False):
|
482 |
+
# calculate nce loss for mean-visual representation and mean-audio representation
|
483 |
+
|
484 |
+
audio_rep = torch.nn.functional.normalize(audio_rep, dim=-1)
|
485 |
+
video_rep = torch.nn.functional.normalize(video_rep, dim=-1)
|
486 |
+
|
487 |
+
total = torch.mm(audio_rep, torch.transpose(video_rep, 0, 1)) / 0.05
|
488 |
+
|
489 |
+
# by default we use single directional
|
490 |
+
if not bidirect_contrast:
|
491 |
+
nce = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total, dim=0)))
|
492 |
+
c_acc = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total, dim=0), dim=0),
|
493 |
+
torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
|
494 |
+
return nce, c_acc
|
495 |
+
else:
|
496 |
+
nce_1 = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total, dim=0)))
|
497 |
+
nce_2 = -torch.mean(torch.diag(torch.nn.functional.log_softmax(total.t(), dim=0)))
|
498 |
+
c_acc_1 = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total, dim=0), dim=0),
|
499 |
+
torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
|
500 |
+
c_acc_2 = torch.sum(torch.eq(torch.argmax(torch.nn.functional.softmax(total.t(), dim=0), dim=0),
|
501 |
+
torch.arange(0, total.shape[0], device=audio_rep.device))) / total.shape[0]
|
502 |
+
nce = (nce_1 + nce_2) / 2
|
503 |
+
c_acc = (c_acc_1 + c_acc_2) / 2
|
504 |
+
return nce, c_acc
|
505 |
+
|
506 |
+
def forward_mae_loss(self, input, pred, mask, modality):
|
507 |
+
if modality == 'a':
|
508 |
+
# for audio, need to adjust the shape
|
509 |
+
input = input.unsqueeze(1)
|
510 |
+
input = input.transpose(2, 3)
|
511 |
+
target = self.patchify(input, 1, int(input.shape[2] / self.patch_embed_a.patch_size[0]),
|
512 |
+
int(input.shape[3] / self.patch_embed_a.patch_size[1]), 16)
|
513 |
+
elif modality == 'v':
|
514 |
+
target = self.patchify(input, 3, int(input.shape[2] / self.patch_embed_v.patch_size[0]),
|
515 |
+
int(input.shape[3] / self.patch_embed_v.patch_size[1]), 16)
|
516 |
+
|
517 |
+
# patch-wise normalization might minorly improve the classification performance, but will make the model lose inpainting function
|
518 |
+
if self.norm_pix_loss:
|
519 |
+
mean = target.mean(dim=-1, keepdim=True)
|
520 |
+
var = target.var(dim=-1, keepdim=True)
|
521 |
+
target = (target - mean) / (var + 1.e-6) ** .5
|
522 |
+
|
523 |
+
loss = (pred - target) ** 2
|
524 |
+
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
525 |
+
|
526 |
+
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
527 |
+
return loss
|
528 |
+
|
529 |
+
def forward(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mae_loss_weight=1., contrast_loss_weight=0.01,
|
530 |
+
mask_mode='unstructured'):
|
531 |
+
# latent is used for reconstruction (mae), latent_c_{a,v} are used for contrastive learning
|
532 |
+
latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(audio, imgs,
|
533 |
+
mask_ratio_a,
|
534 |
+
mask_ratio_v,
|
535 |
+
mask_mode=mask_mode)
|
536 |
+
# if mae loss is used
|
537 |
+
if mae_loss_weight != 0:
|
538 |
+
pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v)
|
539 |
+
loss_mae_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
|
540 |
+
loss_mae_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
|
541 |
+
loss_mae = mae_loss_weight * (loss_mae_a + loss_mae_v)
|
542 |
+
else:
|
543 |
+
loss_mae_a, loss_mae_v, loss_mae = torch.tensor(0.0, device=audio.device), torch.tensor(0.0,
|
544 |
+
device=audio.device), torch.tensor(
|
545 |
+
0.0, device=audio.device)
|
546 |
+
|
547 |
+
# if contrastive loss is used
|
548 |
+
if contrast_loss_weight != 0:
|
549 |
+
# note this is single directional
|
550 |
+
loss_c, c_acc = self.forward_contrastive(latent_c_a.mean(dim=1), latent_c_v.mean(dim=1))
|
551 |
+
loss_c = contrast_loss_weight * loss_c
|
552 |
+
else:
|
553 |
+
loss_c, c_acc = torch.tensor(0.0, device=audio.device), torch.tensor(0.0, device=audio.device)
|
554 |
+
|
555 |
+
loss = loss_mae + loss_c
|
556 |
+
|
557 |
+
return loss, loss_mae, loss_mae_a, loss_mae_v, loss_c, mask_a, mask_v, c_acc
|
558 |
+
|
559 |
+
# used only for inpainting, ignore if inpainting is not of interest
|
560 |
+
def forward_inpaint(self, audio, imgs, mask_ratio_a=0.75, mask_ratio_v=0.75, mask_mode='unstructured'):
|
561 |
+
latent, mask_a, ids_restore_a, mask_v, ids_restore_v, latent_c_a, latent_c_v = self.forward_encoder(audio, imgs,
|
562 |
+
mask_ratio_a,
|
563 |
+
mask_ratio_v,
|
564 |
+
mask_mode=mask_mode)
|
565 |
+
pred_a, pred_v = self.forward_decoder(latent, mask_a, ids_restore_a, mask_v, ids_restore_v) # [N, L, p*p*3]
|
566 |
+
loss_pixel_a = self.forward_mae_loss(audio, pred_a, mask_a, 'a')
|
567 |
+
loss_pixel_v = self.forward_mae_loss(imgs, pred_v, mask_v, 'v')
|
568 |
+
return pred_a, pred_v, mask_a, mask_v, loss_pixel_a, loss_pixel_v
|
569 |
+
|
570 |
+
# used for retrieval, ignore if retrieval is not of interest
|
571 |
+
def forward_feat(self, a, v):
|
572 |
+
# embed patches
|
573 |
+
a = a.unsqueeze(1)
|
574 |
+
a = a.transpose(2, 3)
|
575 |
+
a = self.patch_embed_a(a)
|
576 |
+
a = a + self.pos_embed_a
|
577 |
+
a = a + self.modality_a
|
578 |
+
|
579 |
+
v = self.patch_embed_v(v)
|
580 |
+
v = v + self.pos_embed_v
|
581 |
+
v = v + self.modality_v
|
582 |
+
|
583 |
+
# the modality-specific stream
|
584 |
+
for blk in self.blocks_a:
|
585 |
+
a = blk(a)
|
586 |
+
|
587 |
+
for blk in self.blocks_v:
|
588 |
+
v = blk(v)
|
589 |
+
|
590 |
+
# use modality specific normalization,
|
591 |
+
for blk in self.blocks_u:
|
592 |
+
a = blk(a, 'a')
|
593 |
+
a = self.norm_a(a)
|
594 |
+
|
595 |
+
for blk in self.blocks_u:
|
596 |
+
v = blk(v, 'v')
|
597 |
+
v = self.norm_v(v)
|
598 |
+
return a, v
|
599 |
+
|
600 |
+
def forward_audio(self, a):
|
601 |
+
# embed patches
|
602 |
+
a = a.unsqueeze(1)
|
603 |
+
a = a.transpose(2, 3)
|
604 |
+
a = self.patch_embed_a(a)
|
605 |
+
a = a + self.pos_embed_a
|
606 |
+
a = a + self.modality_a
|
607 |
+
|
608 |
+
# the modality-specific stream
|
609 |
+
for blk in self.blocks_a:
|
610 |
+
a = blk(a)
|
611 |
+
|
612 |
+
# use modality specific normalization,
|
613 |
+
for blk in self.blocks_u:
|
614 |
+
a = blk(a, 'a')
|
615 |
+
a = self.norm_a(a)
|
616 |
+
|
617 |
+
return a.reshape(a.shape[0], 128 // 16, 1024 // 16, 768).permute(0, 3, 1, 2)
|
618 |
+
|
619 |
+
def forward_video(self, v):
|
620 |
+
v = self.patch_embed_v(v)
|
621 |
+
v = v + self.pos_embed_v
|
622 |
+
v = v + self.modality_v
|
623 |
+
|
624 |
+
for blk in self.blocks_v:
|
625 |
+
v = blk(v)
|
626 |
+
|
627 |
+
for blk in self.blocks_u:
|
628 |
+
v = blk(v, 'v')
|
629 |
+
v = self.norm_v(v)
|
630 |
+
return v.reshape(v.shape[0], 224 // 16, 224 // 16, 768).permute(0, 3, 1, 2)
|
631 |
+
|
632 |
+
|
633 |
+
# the finetuned CAV-MAE model
|
634 |
+
class CAVMAEFT(nn.Module):
|
635 |
+
def __init__(self, label_dim, img_size=224, audio_length=1024, patch_size=16, in_chans=3,
|
636 |
+
embed_dim=768, modality_specific_depth=11, num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm,
|
637 |
+
norm_pix_loss=False, tr_pos=True):
|
638 |
+
super().__init__()
|
639 |
+
timm.models.vision_transformer.Block = Block
|
640 |
+
print('Use norm_pix_loss: ', norm_pix_loss)
|
641 |
+
|
642 |
+
timm.models.vision_transformer.PatchEmbed = PatchEmbed
|
643 |
+
timm.models.vision_transformer.Block = Block
|
644 |
+
|
645 |
+
self.patch_embed_a = PatchEmbed(img_size, patch_size, 1, embed_dim)
|
646 |
+
self.patch_embed_v = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
|
647 |
+
|
648 |
+
self.patch_embed_a.num_patches = int(audio_length * 128 / 256)
|
649 |
+
print('Number of Audio Patches: {:d}, Visual Patches: {:d}'.format(self.patch_embed_a.num_patches,
|
650 |
+
self.patch_embed_v.num_patches))
|
651 |
+
|
652 |
+
self.modality_a = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
653 |
+
self.modality_v = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
654 |
+
|
655 |
+
self.pos_embed_a = nn.Parameter(torch.zeros(1, self.patch_embed_a.num_patches, embed_dim),
|
656 |
+
requires_grad=tr_pos) # fixed sin-cos embedding
|
657 |
+
self.pos_embed_v = nn.Parameter(torch.zeros(1, self.patch_embed_v.num_patches, embed_dim),
|
658 |
+
requires_grad=tr_pos) # fixed sin-cos embedding
|
659 |
+
|
660 |
+
self.blocks_a = nn.ModuleList(
|
661 |
+
[Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
|
662 |
+
range(modality_specific_depth)])
|
663 |
+
self.blocks_v = nn.ModuleList(
|
664 |
+
[Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
|
665 |
+
range(modality_specific_depth)])
|
666 |
+
self.blocks_u = nn.ModuleList(
|
667 |
+
[Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) for i in
|
668 |
+
range(12 - modality_specific_depth)])
|
669 |
+
|
670 |
+
self.norm_a = norm_layer(embed_dim)
|
671 |
+
self.norm_v = norm_layer(embed_dim)
|
672 |
+
self.norm = norm_layer(embed_dim)
|
673 |
+
|
674 |
+
self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, label_dim))
|
675 |
+
|
676 |
+
self.initialize_weights()
|
677 |
+
|
678 |
+
print('Audio Positional Embedding Shape:', self.pos_embed_a.shape)
|
679 |
+
print('Visual Positional Embedding Shape:', self.pos_embed_v.shape)
|
680 |
+
|
681 |
+
    def get_patch_num(self, input_shape, stride):
        test_input = torch.zeros(1, 1, input_shape[0], input_shape[1])
        test_proj = torch.nn.Conv2d(1, 4, kernel_size=(16, 16), stride=(stride, stride))
        test_output = test_proj(test_input)
        print(test_output.shape)
        # return the patch grid height, width, and total patch count
        # (the original indexed the tensor itself instead of its shape, which would fail at runtime)
        return test_output.shape[2], test_output.shape[3], test_output.shape[2] * test_output.shape[3]
|
687 |
+
|
688 |
+
def initialize_weights(self):
|
689 |
+
pos_embed_a = get_2d_sincos_pos_embed(self.pos_embed_a.shape[-1], 8, int(self.patch_embed_a.num_patches / 8),
|
690 |
+
cls_token=False)
|
691 |
+
self.pos_embed_a.data.copy_(torch.from_numpy(pos_embed_a).float().unsqueeze(0))
|
692 |
+
|
693 |
+
pos_embed_v = get_2d_sincos_pos_embed(self.pos_embed_v.shape[-1], int(self.patch_embed_v.num_patches ** .5),
|
694 |
+
int(self.patch_embed_v.num_patches ** .5), cls_token=False)
|
695 |
+
self.pos_embed_v.data.copy_(torch.from_numpy(pos_embed_v).float().unsqueeze(0))
|
696 |
+
|
697 |
+
w = self.patch_embed_a.proj.weight.data
|
698 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
699 |
+
w = self.patch_embed_v.proj.weight.data
|
700 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
701 |
+
|
702 |
+
torch.nn.init.normal_(self.modality_a, std=.02)
|
703 |
+
torch.nn.init.normal_(self.modality_v, std=.02)
|
704 |
+
|
705 |
+
self.apply(self._init_weights)
|
706 |
+
|
707 |
+
def _init_weights(self, m):
|
708 |
+
if isinstance(m, nn.Linear):
|
709 |
+
# we use xavier_uniform following official JAX ViT:
|
710 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
711 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
712 |
+
nn.init.constant_(m.bias, 0)
|
713 |
+
elif isinstance(m, nn.LayerNorm):
|
714 |
+
nn.init.constant_(m.bias, 0)
|
715 |
+
nn.init.constant_(m.weight, 1.0)
|
716 |
+
|
717 |
+
def forward(self, a, v, mode):
|
718 |
+
# multi-modal fine-tuning, our default method for fine-tuning
|
719 |
+
if mode == 'multimodal':
|
720 |
+
a = a.unsqueeze(1)
|
721 |
+
a = a.transpose(2, 3)
|
722 |
+
a = self.patch_embed_a(a)
|
723 |
+
a = a + self.pos_embed_a
|
724 |
+
a = a + self.modality_a
|
725 |
+
|
726 |
+
v = self.patch_embed_v(v)
|
727 |
+
v = v + self.pos_embed_v
|
728 |
+
v = v + self.modality_v
|
729 |
+
|
730 |
+
for blk in self.blocks_a:
|
731 |
+
a = blk(a)
|
732 |
+
|
733 |
+
for blk in self.blocks_v:
|
734 |
+
v = blk(v)
|
735 |
+
|
736 |
+
x = torch.cat((a, v), dim=1)
|
737 |
+
|
738 |
+
for blk in self.blocks_u:
|
739 |
+
x = blk(x)
|
740 |
+
x = self.norm(x)
|
741 |
+
|
742 |
+
x = x.mean(dim=1)
|
743 |
+
x = self.mlp_head(x)
|
744 |
+
return x
|
745 |
+
|
746 |
+
# finetune with only audio (and inference with only audio when the model is finetuned with only audio)
|
747 |
+
elif mode == 'audioonly':
|
748 |
+
a = a.unsqueeze(1)
|
749 |
+
a = a.transpose(2, 3)
|
750 |
+
a = self.patch_embed_a(a)
|
751 |
+
a = a + self.pos_embed_a
|
752 |
+
a = a + self.modality_a
|
753 |
+
|
754 |
+
for blk in self.blocks_a:
|
755 |
+
a = blk(a)
|
756 |
+
|
757 |
+
# note here uses the 'a' normalization, it is used in both training and inference, so it is fine
|
758 |
+
for blk in self.blocks_u:
|
759 |
+
a = blk(a, 'a')
|
760 |
+
a = self.norm_a(a)
|
761 |
+
x = a.mean(dim=1)
|
762 |
+
x = self.mlp_head(x)
|
763 |
+
return x
|
764 |
+
|
765 |
+
# finetune with only image (and inference with only audio when the model is finetuned with only image)
|
766 |
+
elif mode == 'videoonly':
|
767 |
+
v = self.patch_embed_v(v)
|
768 |
+
v = v + self.pos_embed_v
|
769 |
+
v = v + self.modality_v
|
770 |
+
|
771 |
+
for blk in self.blocks_v:
|
772 |
+
v = blk(v)
|
773 |
+
|
774 |
+
# note here uses the 'v' normalization, it is used in both training and inference, so it is fine
|
775 |
+
for blk in self.blocks_u:
|
776 |
+
v = blk(v, 'v')
|
777 |
+
v = self.norm_v(v)
|
778 |
+
x = v.mean(dim=1)
|
779 |
+
x = self.mlp_head(x)
|
780 |
+
return x
|
781 |
+
|
782 |
+
# used in case that the model is finetuned with both modality, but in inference only audio is given
|
783 |
+
elif mode == 'missingaudioonly':
|
784 |
+
a = a.unsqueeze(1)
|
785 |
+
a = a.transpose(2, 3)
|
786 |
+
a = self.patch_embed_a(a)
|
787 |
+
a = a + self.pos_embed_a
|
788 |
+
a = a + self.modality_a
|
789 |
+
|
790 |
+
for blk in self.blocks_a:
|
791 |
+
a = blk(a)
|
792 |
+
|
793 |
+
# two forward passes to the block_u, one with modality-specific normalization, another with unified normalization
|
794 |
+
u = a
|
795 |
+
for blk in self.blocks_u:
|
796 |
+
u = blk(u) # note here use unified normalization
|
797 |
+
u = self.norm(u)
|
798 |
+
u = u.mean(dim=1)
|
799 |
+
|
800 |
+
for blk in self.blocks_u:
|
801 |
+
a = blk(a, 'a') # note here use modality-specific normalization
|
802 |
+
a = self.norm_a(a)
|
803 |
+
a = a.mean(dim=1)
|
804 |
+
|
805 |
+
# average the output of the two forward passes
|
806 |
+
x = (u + a) / 2
|
807 |
+
x = self.mlp_head(x)
|
808 |
+
return x
|
809 |
+
|
810 |
+
# used in case that the model is fine-tuned with both modality, but in inference only image is given
|
811 |
+
elif mode == 'missingvideoonly':
|
812 |
+
v = self.patch_embed_v(v)
|
813 |
+
v = v + self.pos_embed_v
|
814 |
+
v = v + self.modality_v
|
815 |
+
|
816 |
+
for blk in self.blocks_v:
|
817 |
+
v = blk(v)
|
818 |
+
|
819 |
+
# two forward passes to the block_u, one with modality-specific normalization, another with unified normalization
|
820 |
+
u = v
|
821 |
+
for blk in self.blocks_u:
|
822 |
+
u = blk(u) # note here use unified normalization
|
823 |
+
u = self.norm(u)
|
824 |
+
u = u.mean(dim=1)
|
825 |
+
|
826 |
+
for blk in self.blocks_u:
|
827 |
+
v = blk(v, 'v') # note here use modality-specific normalization
|
828 |
+
v = self.norm_v(v)
|
829 |
+
v = v.mean(dim=1)
|
830 |
+
|
831 |
+
# average the output of the two forward passes
|
832 |
+
x = (u + v) / 2
|
833 |
+
x = self.mlp_head(x)
|
834 |
+
return x
|
835 |
+
|
836 |
+
# for retrieval
|
837 |
+
def forward_feat(self, a, v, mode='av'):
|
838 |
+
# return both audio and visual
|
839 |
+
if mode == 'av':
|
840 |
+
a = a.unsqueeze(1)
|
841 |
+
a = a.transpose(2, 3)
|
842 |
+
a = self.patch_embed_a(a)
|
843 |
+
a = a + self.pos_embed_a
|
844 |
+
a = a + self.modality_a
|
845 |
+
|
846 |
+
v = self.patch_embed_v(v)
|
847 |
+
v = v + self.pos_embed_v
|
848 |
+
v = v + self.modality_v
|
849 |
+
|
850 |
+
for blk in self.blocks_a:
|
851 |
+
a = blk(a)
|
852 |
+
|
853 |
+
for blk in self.blocks_v:
|
854 |
+
v = blk(v)
|
855 |
+
|
856 |
+
for blk in self.blocks_u:
|
857 |
+
a = blk(a, 'a')
|
858 |
+
a = self.norm_a(a)
|
859 |
+
|
860 |
+
for blk in self.blocks_u:
|
861 |
+
v = blk(v, 'v')
|
862 |
+
|
863 |
+
v = self.norm_v(v)
|
864 |
+
return a, v
|
865 |
+
|
866 |
+
# return only audio
|
867 |
+
if mode == 'a':
|
868 |
+
a = a.unsqueeze(1)
|
869 |
+
a = a.transpose(2, 3)
|
870 |
+
a = self.patch_embed_a(a)
|
871 |
+
a = a + self.pos_embed_a
|
872 |
+
a = a + self.modality_a
|
873 |
+
|
874 |
+
for blk in self.blocks_a:
|
875 |
+
a = blk(a)
|
876 |
+
|
877 |
+
for blk in self.blocks_u:
|
878 |
+
a = blk(a, 'a')
|
879 |
+
|
880 |
+
a = self.norm_a(a)
|
881 |
+
return a
|
882 |
+
|
883 |
+
|
884 |
+
def _wav2fbank(filename):
    waveform, sr = torchaudio.load(filename)
    waveform = torchaudio.functional.resample(
        waveform, orig_freq=sr, new_freq=16000
    )

    waveform = waveform - waveform.mean()

    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=16000,  # the waveform was just resampled to 16 kHz
        use_energy=False,
        window_type='hanning',
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10)

    target_length = 1024
    n_frames = fbank.shape[0]

    p = target_length - n_frames

    # cut and pad
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    return fbank
|
917 |
+
|
918 |
+
|
919 |
+
def pca(image_feats_list, dim=3, fit_pca=None):
    from sklearn.decomposition import PCA

    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        if target_size is not None and fit_pca is None:
            # resize so features from different maps share a spatial resolution before fitting
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    if len(image_feats_list) > 1 and fit_pca is None:
        target_size = image_feats_list[0].shape[2]
    else:
        target_size = None

    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)

    if fit_pca is None:
        fit_pca = PCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values
        B, C, H, W = feats.shape
        reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))

    return reduced_feats, fit_pca
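# Illustrative use of pca() (tensor names and shapes below are made up for the example):
#   feats_a = torch.randn(1, 768, 14, 14)
#   feats_b = torch.randn(1, 768, 14, 14)
#   [red_a], fit = pca([feats_a])               # red_a: (1, 3, 14, 14), scaled to [0, 1]
#   [red_b], _ = pca([feats_b], fit_pca=fit)    # reuse the fitted projection on new features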
|
952 |
+
|
953 |
+
|
954 |
+
class CAVMAEAudioFeaturizer(nn.Module):
|
955 |
+
|
956 |
+
def __init__(self, output_path, model_name="base", model=None):
|
957 |
+
super().__init__()
|
958 |
+
if model is not None:
|
959 |
+
self.model = model
|
960 |
+
else:
|
961 |
+
if model_name == "base":
|
962 |
+
model_path = os.path.join(output_path, 'models/audio_model.21.pth')
|
963 |
+
else:
|
964 |
+
raise ValueError(f"Unknown model type {model_name}")
|
965 |
+
|
966 |
+
audio_model = CAVMAE(
|
967 |
+
audio_length=1024,
|
968 |
+
modality_specific_depth=11,
|
969 |
+
norm_pix_loss=True,
|
970 |
+
tr_pos=False)
|
971 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
972 |
+
mdl_weight = torch.load(model_path, map_location=device)
|
973 |
+
audio_model = torch.nn.DataParallel(audio_model)
|
974 |
+
audio_model.load_state_dict(mdl_weight, strict=True)
|
975 |
+
self.model = audio_model.module.cuda()
|
976 |
+
|
977 |
+
def forward(self, audio, include_cls):
|
978 |
+
cls_token = None
|
979 |
+
patch_tokens = self.model.forward_audio(audio.squeeze(1))
|
980 |
+
|
981 |
+
if include_cls:
|
982 |
+
return patch_tokens, cls_token
|
983 |
+
else:
|
984 |
+
return patch_tokens
|
985 |
+
|
986 |
+
|
987 |
+
class CAVMAEImageFeaturizer(nn.Module):
|
988 |
+
|
989 |
+
def __init__(self, output_path, model=None, model_name="base"):
|
990 |
+
super().__init__()
|
991 |
+
if model is not None:
|
992 |
+
self.model: CAVMAE = model
|
993 |
+
else:
|
994 |
+
if model_name == "base":
|
995 |
+
model_path = os.path.join(output_path, 'models/audio_model.21.pth')
|
996 |
+
else:
|
997 |
+
raise ValueError(f"Unknown model type {model_name}")
|
998 |
+
|
999 |
+
audio_model = CAVMAE(
|
1000 |
+
audio_length=1024,
|
1001 |
+
modality_specific_depth=11,
|
1002 |
+
norm_pix_loss=True,
|
1003 |
+
tr_pos=False)
|
1004 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
1005 |
+
mdl_weight = torch.load(model_path, map_location=device)
|
1006 |
+
audio_model = torch.nn.DataParallel(audio_model)
|
1007 |
+
audio_model.load_state_dict(mdl_weight, strict=True)
|
1008 |
+
self.model: CAVMAE = audio_model.module.cuda()
|
1009 |
+
|
1010 |
+
def forward(self, image, include_cls):
|
1011 |
+
cls_token = None
|
1012 |
+
patch_tokens = self.model.forward_video(image)
|
1013 |
+
|
1014 |
+
if include_cls:
|
1015 |
+
return patch_tokens, cls_token
|
1016 |
+
else:
|
1017 |
+
return patch_tokens
|
1018 |
+
|
1019 |
+
|
1020 |
+
if __name__ == "__main__":
|
1021 |
+
model_path = os.path.join("../../", 'models/audio_model.21.pth')
|
1022 |
+
audio_model = CAVMAE(
|
1023 |
+
audio_length=1024,
|
1024 |
+
modality_specific_depth=11,
|
1025 |
+
norm_pix_loss=True,
|
1026 |
+
tr_pos=False)
|
1027 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
1028 |
+
mdl_weight = torch.load(model_path, map_location=device)
|
1029 |
+
audio_model = torch.nn.DataParallel(audio_model)
|
1030 |
+
audio_model.load_state_dict(mdl_weight, strict=True)
|
1031 |
+
model: CAVMAE = audio_model.module.cuda()
|
1032 |
+
|
1033 |
+
image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"]
|
1034 |
+
audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"]
|
1035 |
+
|
1036 |
+
images = []
|
1037 |
+
audios = []
|
1038 |
+
|
1039 |
+
for image_path in image_paths:
|
1040 |
+
image = Image.open(image_path).convert("RGB")
|
1041 |
+
preprocess = T.Compose([
|
1042 |
+
T.Resize(224, interpolation=Image.BICUBIC),
|
1043 |
+
T.CenterCrop(224),
|
1044 |
+
T.ToTensor(),
|
1045 |
+
T.Normalize(
|
1046 |
+
mean=[0.4850, 0.4560, 0.4060],
|
1047 |
+
std=[0.2290, 0.2240, 0.2250]
|
1048 |
+
)])
|
1049 |
+
images.append(preprocess(image).unsqueeze(0).cuda())
|
1050 |
+
|
1051 |
+
for audio_path in audio_paths:
|
1052 |
+
a = _wav2fbank(audio_path).cuda().unsqueeze(0)
|
1053 |
+
a = (a + 5.081) / (4.4849)
|
1054 |
+
audios.append(a)
|
1055 |
+
|
1056 |
+
audio_feats, image_feats = model.forward_feat(
|
1057 |
+
torch.cat(audios, dim=0), torch.cat(images, dim=0))
|
1058 |
+
|
1059 |
+
audio_feats = F.normalize(audio_feats.mean(1), dim=1)
|
1060 |
+
image_feats = F.normalize(image_feats.mean(1), dim=1)
|
1061 |
+
|
1062 |
+
sims = torch.einsum("bc,dc->bd", image_feats, audio_feats)
|
1063 |
+
print(sims)
|
1064 |
+
|
1065 |
+
print("here")
|
1066 |
+
|
1067 |
+
# a_feat = F.normalize(a_feat, dim=1)
|
1068 |
+
# v_feat = F.normalize(v_feat, dim=1)
|
1069 |
+
|
1070 |
+
# [red_v_feat, red_a_feat], fit_pca = pca([v_feat, a_feat])
|
1071 |
+
#
|
1072 |
+
# [red_v_feat], fit_pca = pca([v_feat])
|
1073 |
+
# [red_a_feat], fit_pca = pca([a_feat])
|
1074 |
+
#
|
1075 |
+
# import matplotlib.pyplot as plt
|
1076 |
+
#
|
1077 |
+
# fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 5))
|
1078 |
+
# ax[0].imshow(red_v_feat[0].permute(1, 2, 0).cpu())
|
1079 |
+
# ax[1].imshow(red_a_feat[0].permute(1, 2, 0).cpu())
|
1080 |
+
# plt.tight_layout()
|
1081 |
+
# plt.show()
|
1082 |
+
# print("here")
|
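# A minimal, self-contained sketch of the InfoNCE objective computed by
# forward_contrastive above, assuming mean-pooled (B, C) audio/visual embeddings
# and the same 0.05 temperature; the function name is illustrative, not part of CAVMAE.
import torch
import torch.nn.functional as F

def info_nce(audio_rep, video_rep, temperature=0.05):
    audio_rep = F.normalize(audio_rep, dim=-1)
    video_rep = F.normalize(video_rep, dim=-1)
    logits = audio_rep @ video_rep.t() / temperature              # (B, B) similarity matrix
    targets = torch.arange(logits.shape[0], device=logits.device)
    nce = F.cross_entropy(logits.t(), targets)                    # matches the log_softmax(dim=0) form above
    acc = (logits.argmax(dim=0) == targets).float().mean()
    return nce, acc

# loss, acc = info_nce(torch.randn(8, 768), torch.randn(8, 768))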
DenseAV/denseav/featurizers/CLIP.py
ADDED
@@ -0,0 +1,50 @@
import clip
import torch
from torch import nn


class CLIPFeaturizer(nn.Module):

    def __init__(self):
        super().__init__()
        self.model, self.preprocess = clip.load("ViT-B/16", device="cpu")
        self.model.eval().cuda()
        self.config = {}

    def get_cls_token(self, img):
        return self.model.encode_image(img).to(torch.float32)

    def forward(self, img, include_cls):
        # note: get_visual_features is not part of the stock `clip` package; it is
        # expected to be provided by the visual model used with this repo
        features = self.model.get_visual_features(img, include_cls)
        new_features = []
        for i in range(2):
            t = features[i]
            if isinstance(t, torch.Tensor):
                new_features.append(t.to(torch.float32))
            else:
                new_features.append(t)

        return new_features


if __name__ == "__main__":
    import torchvision.transforms as T
    from PIL import Image
    from shared import norm, crop_to_divisor

    device = "cuda" if torch.cuda.is_available() else "cpu"

    image = Image.open("../samples/lex1.jpg")
    load_size = 224  # * 3
    transform = T.Compose([
        T.Resize(load_size, Image.BILINEAR),
        # T.CenterCrop(load_size),
        T.ToTensor(),
        lambda x: crop_to_divisor(x, 16),
        norm])

    model = CLIPFeaturizer().cuda()

    # forward takes an explicit include_cls flag
    results = model(transform(image).cuda().unsqueeze(0), include_cls=True)

    print(clip.available_models())
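# The stock `clip` package does not expose per-patch features; one possible way to
# get ViT-B/16 patch tokens (an assumption, not necessarily how this repo patches
# CLIP) is a forward hook on the visual transformer:
import clip
import torch

model, _ = clip.load("ViT-B/16", device="cpu")
captured = {}

def grab_tokens(module, inputs, output):
    captured["tokens"] = output.permute(1, 0, 2)   # (L, B, C) -> (B, L, C)

handle = model.visual.transformer.register_forward_hook(grab_tokens)
with torch.no_grad():
    model.encode_image(torch.randn(1, 3, 224, 224))
handle.remove()
patch_tokens = captured["tokens"][:, 1:, :]        # drop the class token
print(patch_tokens.shape)                          # torch.Size([1, 196, 768])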
DenseAV/denseav/featurizers/DAVENet.py
ADDED
@@ -0,0 +1,162 @@
# Author: David Harwath
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torchvision.models as imagemodels


class Davenet(nn.Module):
    def __init__(self, embedding_dim=1024):
        super(Davenet, self).__init__()
        self.embedding_dim = embedding_dim
        self.batchnorm1 = nn.BatchNorm2d(1)
        self.conv1 = nn.Conv2d(1, 128, kernel_size=(40, 1), stride=(1, 1), padding=(0, 0))
        self.conv2 = nn.Conv2d(128, 256, kernel_size=(1, 11), stride=(1, 1), padding=(0, 5))
        self.conv3 = nn.Conv2d(256, 512, kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
        self.conv4 = nn.Conv2d(512, 512, kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
        self.conv5 = nn.Conv2d(512, embedding_dim, kernel_size=(1, 17), stride=(1, 1), padding=(0, 8))
        self.pool = nn.MaxPool2d(kernel_size=(1, 3), stride=(1, 2), padding=(0, 1))

    def forward(self, x):
        if x.dim() == 3:
            x = x.unsqueeze(1)
        x = self.batchnorm1(x)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = F.relu(self.conv4(x))
        x = self.pool(x)
        x = F.relu(self.conv5(x))
        x = self.pool(x)
        x = x.squeeze(2)
        return x


class Resnet18(imagemodels.ResNet):
    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet18, self).__init__(imagemodels.resnet.BasicBlock, [2, 2, 2, 2])
        if pretrained:
            self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet18']))
        self.avgpool = None
        self.fc = None
        self.embedder = nn.Conv2d(512, embedding_dim, kernel_size=1, stride=1, padding=0)
        self.embedding_dim = embedding_dim
        self.pretrained = pretrained

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.embedder(x)
        return x


class Resnet34(imagemodels.ResNet):
    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet34, self).__init__(imagemodels.resnet.BasicBlock, [3, 4, 6, 3])
        if pretrained:
            self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet34']))
        self.avgpool = None
        self.fc = None
        self.embedder = nn.Conv2d(512, embedding_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.embedder(x)
        return x


class Resnet50(imagemodels.ResNet):
    def __init__(self, embedding_dim=1024, pretrained=False):
        super(Resnet50, self).__init__(imagemodels.resnet.Bottleneck, [3, 4, 6, 3])
        if pretrained:
            self.load_state_dict(model_zoo.load_url(imagemodels.resnet.model_urls['resnet50']))
        self.avgpool = None
        self.fc = None
        self.embedder = nn.Conv2d(2048, embedding_dim, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.embedder(x)
        return x


class VGG16(nn.Module):
    def __init__(self, embedding_dim=1024, pretrained=False):
        super(VGG16, self).__init__()
        seed_model = imagemodels.__dict__['vgg16'](pretrained=pretrained).features
        seed_model = nn.Sequential(*list(seed_model.children())[:-1])  # remove final maxpool
        last_layer_index = len(list(seed_model.children()))
        seed_model.add_module(str(last_layer_index),
                              nn.Conv2d(512, embedding_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
        self.image_model = seed_model

    def forward(self, x):
        x = self.image_model(x)
        return x


def prep(dict):
    return {k.replace("module.", ""): v for k, v in dict.items()}


class DavenetAudioFeaturizer(nn.Module):

    def __init__(self):
        super().__init__()
        self.audio_model = Davenet()
        self.audio_model.load_state_dict(prep(torch.load("../models/davenet_pt_audio.pth")))

    def forward(self, audio, include_cls):
        patch_tokens = self.audio_model(audio).unsqueeze(2)

        if include_cls:
            return patch_tokens, None
        else:
            return patch_tokens

    def get_last_params(self):
        return []


class DavenetImageFeaturizer(nn.Module):

    def __init__(self):
        super().__init__()
        self.image_model = VGG16()
        self.image_model.load_state_dict(prep(torch.load("../models/davenet_pt_image.pth")))

    def forward(self, image, include_cls):
        patch_tokens = self.image_model(image)

        if include_cls:
            return patch_tokens, None
        else:
            return patch_tokens

    def get_last_params(self):
        return []
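# Quick shape check for the Davenet audio tower above (illustrative only, assuming
# Davenet from this file is importable): it expects 40 mel bins and emits one
# 1024-d vector per ~16 input frames.
import torch
model = Davenet(embedding_dim=1024)
spec = torch.randn(2, 40, 1024)    # (batch, mel bins, frames)
feats = model(spec)
print(feats.shape)                 # torch.Size([2, 1024, 64])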
DenseAV/denseav/featurizers/DINO.py
ADDED
@@ -0,0 +1,451 @@
1 |
+
import math
|
2 |
+
import warnings
|
3 |
+
from functools import partial
|
4 |
+
|
5 |
+
import timm
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
eps = 1e-4
|
10 |
+
|
11 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
12 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
13 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
14 |
+
def norm_cdf(x):
|
15 |
+
# Computes standard normal cumulative distribution function
|
16 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
17 |
+
|
18 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
19 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
20 |
+
"The distribution of values may be incorrect.",
|
21 |
+
stacklevel=2)
|
22 |
+
|
23 |
+
with torch.no_grad():
|
24 |
+
# Values are generated by using a truncated uniform distribution and
|
25 |
+
# then using the inverse CDF for the normal distribution.
|
26 |
+
# Get upper and lower cdf values
|
27 |
+
l = norm_cdf((a - mean) / std)
|
28 |
+
u = norm_cdf((b - mean) / std)
|
29 |
+
|
30 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
31 |
+
# [2l-1, 2u-1].
|
32 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
33 |
+
|
34 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
35 |
+
# standard normal
|
36 |
+
tensor.erfinv_()
|
37 |
+
|
38 |
+
# Transform to proper mean, std
|
39 |
+
tensor.mul_(std * math.sqrt(2.))
|
40 |
+
tensor.add_(mean)
|
41 |
+
|
42 |
+
# Clamp to ensure it's in the proper range
|
43 |
+
tensor.clamp_(min=a, max=b)
|
44 |
+
return tensor
|
45 |
+
|
46 |
+
|
47 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
48 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
49 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
54 |
+
if drop_prob == 0. or not training:
|
55 |
+
return x
|
56 |
+
keep_prob = 1 - drop_prob
|
57 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
58 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
59 |
+
random_tensor.floor_() # binarize
|
60 |
+
output = x.div(keep_prob) * random_tensor
|
61 |
+
return output
|
62 |
+
|
63 |
+
|
64 |
+
class DropPath(nn.Module):
|
65 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
66 |
+
"""
|
67 |
+
|
68 |
+
def __init__(self, drop_prob=None):
|
69 |
+
super(DropPath, self).__init__()
|
70 |
+
self.drop_prob = drop_prob
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
return drop_path(x, self.drop_prob, self.training)
|
74 |
+
|
75 |
+
|
76 |
+
class Mlp(nn.Module):
|
77 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
78 |
+
super().__init__()
|
79 |
+
out_features = out_features or in_features
|
80 |
+
hidden_features = hidden_features or in_features
|
81 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
82 |
+
self.act = act_layer()
|
83 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
84 |
+
self.drop = nn.Dropout(drop)
|
85 |
+
|
86 |
+
def forward(self, x):
|
87 |
+
x = self.fc1(x)
|
88 |
+
x = self.act(x)
|
89 |
+
x = self.drop(x)
|
90 |
+
x = self.fc2(x)
|
91 |
+
x = self.drop(x)
|
92 |
+
return x
|
93 |
+
|
94 |
+
|
95 |
+
class Attention(nn.Module):
|
96 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
97 |
+
super().__init__()
|
98 |
+
self.num_heads = num_heads
|
99 |
+
head_dim = dim // num_heads
|
100 |
+
self.scale = qk_scale or head_dim ** -0.5
|
101 |
+
|
102 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
103 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
104 |
+
self.proj = nn.Linear(dim, dim)
|
105 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
106 |
+
|
107 |
+
def forward(self, x, return_qkv=False):
|
108 |
+
B, N, C = x.shape
|
109 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
110 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
111 |
+
|
112 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
113 |
+
attn = attn.softmax(dim=-1)
|
114 |
+
attn = self.attn_drop(attn)
|
115 |
+
|
116 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
117 |
+
x = self.proj(x)
|
118 |
+
x = self.proj_drop(x)
|
119 |
+
return x, attn, qkv
|
120 |
+
|
121 |
+
|
122 |
+
class Block(nn.Module):
|
123 |
+
def __init__(self, dim,
|
124 |
+
num_heads,
|
125 |
+
mlp_ratio=4.,
|
126 |
+
qkv_bias=False,
|
127 |
+
qk_scale=None,
|
128 |
+
drop=0.,
|
129 |
+
attn_drop=0.,
|
130 |
+
drop_path=0.,
|
131 |
+
act_layer=nn.GELU,
|
132 |
+
norm_layer=nn.LayerNorm):
|
133 |
+
super().__init__()
|
134 |
+
self.norm1 = norm_layer(dim)
|
135 |
+
self.attn = Attention(
|
136 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
137 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
138 |
+
self.norm2 = norm_layer(dim)
|
139 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
140 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
141 |
+
|
142 |
+
def forward(self, x, return_attention=False, return_qkv=False):
|
143 |
+
y, attn, qkv = self.attn(self.norm1(x))
|
144 |
+
if return_attention:
|
145 |
+
return attn
|
146 |
+
x = x + self.drop_path(y)
|
147 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
148 |
+
if return_qkv:
|
149 |
+
return x, attn, qkv
|
150 |
+
return x
|
151 |
+
|
152 |
+
|
153 |
+
class PatchEmbed(nn.Module):
|
154 |
+
""" Image to Patch Embedding
|
155 |
+
"""
|
156 |
+
|
157 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
158 |
+
super().__init__()
|
159 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
160 |
+
self.img_size = img_size
|
161 |
+
self.patch_size = patch_size
|
162 |
+
self.num_patches = num_patches
|
163 |
+
|
164 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
165 |
+
|
166 |
+
def forward(self, x):
|
167 |
+
B, C, H, W = x.shape
|
168 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
169 |
+
return x
|
170 |
+
|
171 |
+
|
172 |
+
class VisionTransformer(nn.Module):
|
173 |
+
""" Vision Transformer """
|
174 |
+
|
175 |
+
def __init__(self,
|
176 |
+
img_size=[224],
|
177 |
+
patch_size=16,
|
178 |
+
in_chans=3,
|
179 |
+
num_classes=0,
|
180 |
+
embed_dim=768,
|
181 |
+
depth=12,
|
182 |
+
num_heads=12,
|
183 |
+
mlp_ratio=4.,
|
184 |
+
qkv_bias=False,
|
185 |
+
qk_scale=None,
|
186 |
+
drop_rate=0.,
|
187 |
+
attn_drop_rate=0.,
|
188 |
+
drop_path_rate=0.,
|
189 |
+
norm_layer=nn.LayerNorm,
|
190 |
+
**kwargs):
|
191 |
+
super().__init__()
|
192 |
+
|
193 |
+
self.num_features = self.embed_dim = embed_dim
|
194 |
+
|
195 |
+
self.patch_embed = PatchEmbed(
|
196 |
+
img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
197 |
+
num_patches = self.patch_embed.num_patches
|
198 |
+
|
199 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
200 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
201 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
202 |
+
|
203 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
204 |
+
self.blocks = nn.ModuleList([
|
205 |
+
Block(
|
206 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
207 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
208 |
+
for i in range(depth)])
|
209 |
+
self.norm = norm_layer(embed_dim)
|
210 |
+
|
211 |
+
# Classifier head
|
212 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
213 |
+
|
214 |
+
trunc_normal_(self.pos_embed, std=.02)
|
215 |
+
trunc_normal_(self.cls_token, std=.02)
|
216 |
+
self.apply(self._init_weights)
|
217 |
+
|
218 |
+
def _init_weights(self, m):
|
219 |
+
if isinstance(m, nn.Linear):
|
220 |
+
trunc_normal_(m.weight, std=.02)
|
221 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
222 |
+
nn.init.constant_(m.bias, 0)
|
223 |
+
elif isinstance(m, nn.LayerNorm):
|
224 |
+
nn.init.constant_(m.bias, 0)
|
225 |
+
nn.init.constant_(m.weight, 1.0)
|
226 |
+
|
227 |
+
def interpolate_pos_encoding(self, x, w, h):
|
228 |
+
npatch = x.shape[1] - 1
|
229 |
+
N = self.pos_embed.shape[1] - 1
|
230 |
+
if npatch == N and w == h:
|
231 |
+
return self.pos_embed
|
232 |
+
class_pos_embed = self.pos_embed[:, 0]
|
233 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
234 |
+
dim = x.shape[-1]
|
235 |
+
w0 = w // self.patch_embed.patch_size
|
236 |
+
h0 = h // self.patch_embed.patch_size
|
237 |
+
# we add a small number to avoid floating point error in the interpolation
|
238 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
239 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
240 |
+
patch_pos_embed = nn.functional.interpolate(
|
241 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
242 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
243 |
+
mode='bicubic',
|
244 |
+
)
|
245 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
246 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
|
247 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
248 |
+
|
249 |
+
def prepare_tokens(self, x):
|
250 |
+
B, nc, w, h = x.shape
|
251 |
+
x = self.patch_embed(x) # patch linear embedding
|
252 |
+
|
253 |
+
# add the [CLS] token to the embed patch tokens
|
254 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
255 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
256 |
+
|
257 |
+
# add positional encoding to each token
|
258 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
259 |
+
|
260 |
+
return self.pos_drop(x)
|
261 |
+
|
262 |
+
def forward(self, x):
|
263 |
+
x = self.prepare_tokens(x)
|
264 |
+
for blk in self.blocks:
|
265 |
+
x = blk(x)
|
266 |
+
x = self.norm(x)
|
267 |
+
return x[:, 0]
|
268 |
+
|
269 |
+
def forward_feats(self, x):
|
270 |
+
x = self.prepare_tokens(x)
|
271 |
+
for blk in self.blocks:
|
272 |
+
x = blk(x)
|
273 |
+
x = self.norm(x)
|
274 |
+
return x
|
275 |
+
|
276 |
+
def get_intermediate_feat(self, x, n=1, norm=True):
|
277 |
+
x = self.prepare_tokens(x)
|
278 |
+
# we return the output tokens from the `n` last blocks
|
279 |
+
feat = []
|
280 |
+
attns = []
|
281 |
+
qkvs = []
|
282 |
+
for i, blk in enumerate(self.blocks):
|
283 |
+
x, attn, qkv = blk(x, return_qkv=True)
|
284 |
+
if len(self.blocks) - i <= n:
|
285 |
+
if norm:
|
286 |
+
feat.append(self.norm(x))
|
287 |
+
else:
|
288 |
+
feat.append(x)
|
289 |
+
qkvs.append(qkv)
|
290 |
+
attns.append(attn)
|
291 |
+
return feat, attns, qkvs
|
292 |
+
|
293 |
+
def get_last_selfattention(self, x):
|
294 |
+
x = self.prepare_tokens(x)
|
295 |
+
for i, blk in enumerate(self.blocks):
|
296 |
+
if i < len(self.blocks) - 1:
|
297 |
+
x = blk(x)
|
298 |
+
else:
|
299 |
+
# return attention of the last block
|
300 |
+
return blk(x, return_attention=True)
|
301 |
+
|
302 |
+
def get_intermediate_layers(self, x, n=1):
|
303 |
+
x = self.prepare_tokens(x)
|
304 |
+
# we return the output tokens from the `n` last blocks
|
305 |
+
output = []
|
306 |
+
for i, blk in enumerate(self.blocks):
|
307 |
+
x = blk(x)
|
308 |
+
if len(self.blocks) - i <= n:
|
309 |
+
output.append(self.norm(x))
|
310 |
+
return output
|
311 |
+
|
312 |
+
|
313 |
+
def vit_tiny(patch_size=16, **kwargs):
|
314 |
+
model = VisionTransformer(
|
315 |
+
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
|
316 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=eps), **kwargs)
|
317 |
+
return model
|
318 |
+
|
319 |
+
|
320 |
+
def vit_small(patch_size=16, **kwargs):
|
321 |
+
model = VisionTransformer(
|
322 |
+
patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
|
323 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=eps), **kwargs)
|
324 |
+
return model
|
325 |
+
|
326 |
+
|
327 |
+
def vit_base(patch_size=16, **kwargs):
|
328 |
+
model = VisionTransformer(
|
329 |
+
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
330 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=eps), **kwargs)
|
331 |
+
return model
|
332 |
+
|
333 |
+
|
334 |
+
class DINOHead(nn.Module):
|
335 |
+
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048,
|
336 |
+
bottleneck_dim=256):
|
337 |
+
super().__init__()
|
338 |
+
nlayers = max(nlayers, 1)
|
339 |
+
if nlayers == 1:
|
340 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
341 |
+
else:
|
342 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
343 |
+
if use_bn:
|
344 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
345 |
+
layers.append(nn.GELU())
|
346 |
+
for _ in range(nlayers - 2):
|
347 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
348 |
+
if use_bn:
|
349 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
350 |
+
layers.append(nn.GELU())
|
351 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
352 |
+
self.mlp = nn.Sequential(*layers)
|
353 |
+
self.apply(self._init_weights)
|
354 |
+
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
355 |
+
self.last_layer.weight_g.data.fill_(1)
|
356 |
+
if norm_last_layer:
|
357 |
+
self.last_layer.weight_g.requires_grad = False
|
358 |
+
|
359 |
+
def _init_weights(self, m):
|
360 |
+
if isinstance(m, nn.Linear):
|
361 |
+
trunc_normal_(m.weight, std=.02)
|
362 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
363 |
+
nn.init.constant_(m.bias, 0)
|
364 |
+
|
365 |
+
def forward(self, x):
|
366 |
+
x = self.mlp(x)
|
367 |
+
x = nn.functional.normalize(x, dim=-1, p=2)
|
368 |
+
x = self.last_layer(x)
|
369 |
+
return x
|
370 |
+
|
371 |
+
|
372 |
+
|
373 |
+
class DINOFeaturizer(nn.Module):
|
374 |
+
|
375 |
+
def __init__(self, arch, patch_size, feat_type):
|
376 |
+
super().__init__()
|
377 |
+
self.arch = arch
|
378 |
+
self.patch_size = patch_size
|
379 |
+
self.feat_type = feat_type
|
380 |
+
|
381 |
+
self.config = {
|
382 |
+
"arch": arch,
|
383 |
+
"patch_size": patch_size,
|
384 |
+
"feat_type": feat_type
|
385 |
+
}
|
386 |
+
|
387 |
+
self.model = vit_small(
|
388 |
+
patch_size=patch_size,
|
389 |
+
num_classes=0)
|
390 |
+
|
391 |
+
if "3d-dino" in arch:
|
392 |
+
state_dict = torch.load("../models/3d-dino-co3d.pth")["teacher"]
|
393 |
+
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
|
394 |
+
state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
|
395 |
+
elif "iarpa-dino" in arch:
|
396 |
+
state_dict = torch.load("../models/dino_iarpa.pth")["teacher"]
|
397 |
+
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
|
398 |
+
state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
|
399 |
+
elif "chk-dino" in arch:
|
400 |
+
state_dict = torch.load("../models/dino_deitsmall16_pretrain_full_checkpoint.pth")["teacher"]
|
401 |
+
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
|
402 |
+
state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
|
403 |
+
elif "ft_dino" in arch:
|
404 |
+
arch = "_".join(arch.split("_")[:-1])
|
405 |
+
state_dict = torch.load("../models/{}.pth".format(arch))["teacher"]
|
406 |
+
state_dict = {k.replace("module.", "").replace("backbone.", ""): v for k, v in state_dict.items()}
|
407 |
+
state_dict = {k: v for k, v in state_dict.items() if "head." not in k}
|
408 |
+
elif "dino" in arch:
|
409 |
+
state_dict = torch.hub.load('facebookresearch/dino:main', self.arch).state_dict()
|
410 |
+
else: # model from timm -- load weights from timm to dino model (enables working on arbitrary size images).
|
411 |
+
temp_model = timm.create_model(self.arch, pretrained=True)
|
412 |
+
state_dict = temp_model.state_dict()
|
413 |
+
del state_dict['head.weight']
|
414 |
+
del state_dict['head.bias']
|
415 |
+
|
416 |
+
self.model.load_state_dict(state_dict, strict=True)
|
417 |
+
|
418 |
+
if arch == "vit_small":
|
419 |
+
self.n_feats = 384
|
420 |
+
else:
|
421 |
+
self.n_feats = 768
|
422 |
+
|
423 |
+
def get_cls_token(self, img):
|
424 |
+
return self.model.forward(img)
|
425 |
+
|
426 |
+
def forward(self, img, n=1, include_cls=False):
|
427 |
+
assert (img.shape[2] % self.patch_size == 0)
|
428 |
+
assert (img.shape[3] % self.patch_size == 0)
|
429 |
+
|
430 |
+
feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
|
431 |
+
feat, attn, qkv = feat[0], attn[0], qkv[0]
|
432 |
+
|
433 |
+
feat_h = img.shape[2] // self.patch_size
|
434 |
+
feat_w = img.shape[3] // self.patch_size
|
435 |
+
|
436 |
+
if self.feat_type == "token":
|
437 |
+
image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
|
438 |
+
cls_feat = feat[:, 0, :]
|
439 |
+
elif self.feat_type == "key":
|
440 |
+
x = qkv[1, :, :, 1:, :] # remove cls token
|
441 |
+
desc = x.permute(0, 2, 3, 1).flatten(start_dim=-2, end_dim=-1)
|
442 |
+
image_feat = desc.reshape(desc.shape[0], feat_h, feat_w, desc.shape[2]) \
|
443 |
+
.permute(0, 3, 1, 2)
|
444 |
+
cls_feat = None
|
445 |
+
else:
|
446 |
+
raise ValueError("Unknown feat type:{}".format(self.feat_type))
|
447 |
+
|
448 |
+
if include_cls:
|
449 |
+
return image_feat, cls_feat
|
450 |
+
|
451 |
+
return image_feat
|
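# Illustrative usage of DINOFeaturizer above; this assumes network access so that
# torch.hub can download dino_vits16, and a CUDA device for the .cuda() calls.
import torch
featurizer = DINOFeaturizer(arch="dino_vits16", patch_size=16, feat_type="token").cuda()
img = torch.randn(1, 3, 224, 224).cuda()
with torch.no_grad():
    patch_feats, cls_feat = featurizer(img, n=1, include_cls=True)
print(patch_feats.shape, cls_feat.shape)   # (1, 384, 14, 14) and (1, 384)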
DenseAV/denseav/featurizers/DINOv2.py
ADDED
@@ -0,0 +1,49 @@
import torch
import torch.nn as nn


class DINOv2Featurizer(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14').cuda()
        # self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
        self.model.eval()
        self.config = {}

    def get_cls_token(self, img):
        pass

    def forward(self, img, include_cls):
        feature_dict = self.model.forward_features(img)
        _, _, h, w = img.shape
        new_h, new_w = h // 14, w // 14
        b, _, c = feature_dict["x_norm_patchtokens"].shape
        spatial_tokens = feature_dict["x_norm_patchtokens"].permute(0, 2, 1).reshape(b, c, new_h, new_w)

        if include_cls:
            return spatial_tokens, feature_dict["x_norm_clstoken"]
        else:
            return spatial_tokens


if __name__ == "__main__":
    import torchvision.transforms as T
    from PIL import Image
    from shared import norm, crop_to_divisor

    device = "cuda" if torch.cuda.is_available() else "cpu"

    image = Image.open("../../samples/dog_man_1_crop.jpg")
    load_size = 224  # * 3
    transform = T.Compose([
        T.Resize(load_size, Image.BILINEAR),
        T.CenterCrop(load_size),
        T.ToTensor(),
        norm])

    model = DINOv2Featurizer().cuda()

    results = model(transform(image).cuda().unsqueeze(0), include_cls=False)

    print(results.shape)
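# DINOv2 ViT-B/14 uses 14-pixel patches and its patch embedding expects input
# height and width to be multiples of 14, so crop before calling DINOv2Featurizer.
# A small centre-crop sketch (illustrative, not the repo's transform):
import torch

def crop_to_multiple_of_14(img):
    _, _, h, w = img.shape
    new_h, new_w = (h // 14) * 14, (w // 14) * 14
    top, left = (h - new_h) // 2, (w - new_w) // 2
    return img[:, :, top:top + new_h, left:left + new_w]

print(crop_to_multiple_of_14(torch.randn(1, 3, 230, 310)).shape)   # torch.Size([1, 3, 224, 308])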
DenseAV/denseav/featurizers/Hubert.py
ADDED
@@ -0,0 +1,70 @@
import torch
import torch.nn as nn
from transformers import Wav2Vec2Processor, HubertModel, HubertConfig
from transformers.pytorch_utils import Conv1D


class HubertAudioTransform():

    def __init__(self):
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")

    def __call__(self, audio):
        return self.processor(audio, return_tensors="pt", sampling_rate=16000).input_values.squeeze(0)


def copy_conv(l):
    # unfinished helper that is never called; Conv1D would need (nf, nx) arguments
    new_l = Conv1D()


class Hubert(nn.Module):
    def __init__(self):
        super().__init__()
        model1 = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft")
        config = model1.config
        del model1
        config.layer_norm_eps = 1e-4
        self.model = HubertModel.from_pretrained("facebook/hubert-large-ls960-ft", config=config)
        self.config = dict()

    def forward(self, audio, include_cls):
        outputs = self.model(audio)
        # outputs = deepspeed.checkpointing.checkpoint(self.model, audio)

        patch_tokens = outputs.last_hidden_state.permute(0, 2, 1).unsqueeze(2)

        if include_cls:
            return patch_tokens, None
        else:
            return patch_tokens

    def get_last_params(self):
        return self.model.encoder.layers[-1].parameters()


if __name__ == "__main__":
    import librosa
    from shared import pca, remove_axes
    import matplotlib.pyplot as plt
    from pytorch_lightning import seed_everything

    audio, _ = librosa.load("../../samples/example.wav", sr=16000)
    audio = torch.from_numpy(audio).unsqueeze(0).to("cuda")

    model = Hubert().to("cuda")
    embeddings = model.forward(audio, include_cls=False)

    print(embeddings.shape)
    seed_everything(0)

    with torch.no_grad():
        [pca_feats], _ = pca([embeddings])
        pca_feats = torch.broadcast_to(
            pca_feats, (pca_feats.shape[0], pca_feats.shape[1], 25, pca_feats.shape[3]))
        fig, axes = plt.subplots(2, 1, figsize=(10, 7))
        axes[1].imshow(pca_feats.cpu().squeeze(0).permute(1, 2, 0))
        remove_axes(axes)
        plt.tight_layout()
        plt.show()
        print("here")
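# Small shape sketch for the Hubert wrapper above (downloads facebook/hubert-large-ls960-ft;
# illustrative only): hubert-large emits one 1024-d token roughly every 20 ms, so 1 s of
# 16 kHz audio gives 49 tokens in the (B, C, 1, T) layout returned by forward().
import torch
model = Hubert().eval()
audio = torch.randn(1, 16000)              # 1 second at 16 kHz
with torch.no_grad():
    tokens = model(audio, include_cls=False)
print(tokens.shape)                        # torch.Size([1, 1024, 1, 49])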
DenseAV/denseav/featurizers/ImageBind.py
ADDED
@@ -0,0 +1,2033 @@
1 |
+
import gzip
|
2 |
+
import html
|
3 |
+
import io
|
4 |
+
import logging
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
from functools import lru_cache
|
8 |
+
from functools import partial
|
9 |
+
from types import SimpleNamespace
|
10 |
+
from typing import Callable, List
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
+
import einops
|
14 |
+
import ftfy
|
15 |
+
import numpy as np
|
16 |
+
import regex as re
|
17 |
+
import torch
|
18 |
+
import torch.nn as nn
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import torch.utils.checkpoint as checkpoint
|
21 |
+
import torchaudio
|
22 |
+
import torchvision.transforms as T
|
23 |
+
from PIL import Image
|
24 |
+
from timm.models.layers import DropPath, trunc_normal_
|
25 |
+
from torchvision import transforms
|
26 |
+
import matplotlib.pyplot as plt
|
27 |
+
from iopath.common.file_io import g_pathmgr
|
28 |
+
|
29 |
+
|
30 |
+
class Attention(nn.Module):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
dim,
|
34 |
+
num_heads=8,
|
35 |
+
qkv_bias=False,
|
36 |
+
qk_scale=None,
|
37 |
+
attn_drop=0.0,
|
38 |
+
proj_drop=0.0,
|
39 |
+
):
|
40 |
+
super().__init__()
|
41 |
+
self.num_heads = num_heads
|
42 |
+
head_dim = dim // num_heads
|
43 |
+
# NOTE scale factor was wrong in my original version,
|
44 |
+
# can set manually to be compat with prev weights
|
45 |
+
self.scale = qk_scale or head_dim ** -0.5
|
46 |
+
|
47 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
48 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
49 |
+
self.proj = nn.Linear(dim, dim)
|
50 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
B, N, C = x.shape
|
54 |
+
qkv = (
|
55 |
+
self.qkv(x)
|
56 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
57 |
+
.permute(2, 0, 3, 1, 4)
|
58 |
+
)
|
59 |
+
q, k, v = (
|
60 |
+
qkv[0],
|
61 |
+
qkv[1],
|
62 |
+
qkv[2],
|
63 |
+
) # make torchscript happy (cannot use tensor as tuple)
|
64 |
+
|
65 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
66 |
+
attn = attn.softmax(dim=-1)
|
67 |
+
attn = self.attn_drop(attn)
|
68 |
+
|
69 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
70 |
+
x = self.proj(x)
|
71 |
+
x = self.proj_drop(x)
|
72 |
+
return x
|
73 |
+
|
74 |
+
|
75 |
+
class Mlp(nn.Module):
|
76 |
+
def __init__(
|
77 |
+
self,
|
78 |
+
in_features,
|
79 |
+
hidden_features=None,
|
80 |
+
out_features=None,
|
81 |
+
act_layer=nn.GELU,
|
82 |
+
drop=0.0,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
out_features = out_features or in_features
|
86 |
+
hidden_features = hidden_features or in_features
|
87 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
88 |
+
self.act = act_layer()
|
89 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
90 |
+
self.drop = nn.Dropout(drop)
|
91 |
+
|
92 |
+
def forward(self, x):
|
93 |
+
x = self.fc1(x)
|
94 |
+
x = self.act(x)
|
95 |
+
x = self.drop(x)
|
96 |
+
x = self.fc2(x)
|
97 |
+
x = self.drop(x)
|
98 |
+
return x
|
99 |
+
|
100 |
+
|
101 |
+
class MultiheadAttention(nn.MultiheadAttention):
|
102 |
+
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
|
103 |
+
return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
104 |
+
|
105 |
+
|
106 |
+
class ViTAttention(Attention):
|
107 |
+
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
|
108 |
+
assert attn_mask is None
|
109 |
+
return super().forward(x)
|
110 |
+
|
111 |
+
|
112 |
+
class BlockWithMasking(nn.Module):
|
113 |
+
def __init__(
|
114 |
+
self,
|
115 |
+
dim: int,
|
116 |
+
attn_target: Callable,
|
117 |
+
mlp_ratio: int = 4,
|
118 |
+
act_layer: Callable = nn.GELU,
|
119 |
+
norm_layer: Callable = nn.LayerNorm,
|
120 |
+
ffn_dropout_rate: float = 0.0,
|
121 |
+
drop_path: float = 0.0,
|
122 |
+
layer_scale_type: str = None,
|
123 |
+
layer_scale_init_value: float = 1e-4,
|
124 |
+
):
|
125 |
+
super().__init__()
|
126 |
+
|
127 |
+
assert not isinstance(
|
128 |
+
attn_target, nn.Module
|
129 |
+
), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
|
130 |
+
self.attn = attn_target()
|
131 |
+
if drop_path > 0.0:
|
132 |
+
self.drop_path = DropPath(drop_path)
|
133 |
+
else:
|
134 |
+
self.drop_path = nn.Identity()
|
135 |
+
self.norm_1 = norm_layer(dim)
|
136 |
+
mlp_hidden_dim = int(mlp_ratio * dim)
|
137 |
+
self.mlp = Mlp(
|
138 |
+
in_features=dim,
|
139 |
+
hidden_features=mlp_hidden_dim,
|
140 |
+
act_layer=act_layer,
|
141 |
+
drop=ffn_dropout_rate,
|
142 |
+
)
|
143 |
+
self.norm_2 = norm_layer(dim)
|
144 |
+
self.layer_scale_type = layer_scale_type
|
145 |
+
if self.layer_scale_type is not None:
|
146 |
+
assert self.layer_scale_type in [
|
147 |
+
"per_channel",
|
148 |
+
"scalar",
|
149 |
+
], f"Found Layer scale type {self.layer_scale_type}"
|
150 |
+
if self.layer_scale_type == "per_channel":
|
151 |
+
# one gamma value per channel
|
152 |
+
gamma_shape = [1, 1, dim]
|
153 |
+
elif self.layer_scale_type == "scalar":
|
154 |
+
# single gamma value for all channels
|
155 |
+
gamma_shape = [1, 1, 1]
|
156 |
+
# two gammas: for each part of the fwd in the encoder
|
157 |
+
self.layer_scale_gamma1 = nn.Parameter(
|
158 |
+
torch.ones(size=gamma_shape) * layer_scale_init_value,
|
159 |
+
requires_grad=True,
|
160 |
+
)
|
161 |
+
self.layer_scale_gamma2 = nn.Parameter(
|
162 |
+
torch.ones(size=gamma_shape) * layer_scale_init_value,
|
163 |
+
requires_grad=True,
|
164 |
+
)
|
165 |
+
|
166 |
+
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
|
167 |
+
if self.layer_scale_type is None:
|
168 |
+
x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
|
169 |
+
x = x + self.drop_path(self.mlp(self.norm_2(x)))
|
170 |
+
else:
|
171 |
+
x = (
|
172 |
+
x
|
173 |
+
+ self.drop_path(self.attn(self.norm_1(x), attn_mask))
|
174 |
+
* self.layer_scale_gamma1
|
175 |
+
)
|
176 |
+
x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
|
177 |
+
return x
|
178 |
+
|
179 |
+
|
180 |
+
_LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
|
181 |
+
|
182 |
+
|
183 |
+
class SimpleTransformer(nn.Module):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
attn_target: Callable,
|
187 |
+
embed_dim: int,
|
188 |
+
num_blocks: int,
|
189 |
+
block: Callable = BlockWithMasking,
|
190 |
+
pre_transformer_layer: Callable = None,
|
191 |
+
post_transformer_layer: Callable = None,
|
192 |
+
drop_path_rate: float = 0.0,
|
193 |
+
drop_path_type: str = "progressive",
|
194 |
+
norm_layer: Callable = _LAYER_NORM,
|
195 |
+
mlp_ratio: int = 4,
|
196 |
+
ffn_dropout_rate: float = 0.0,
|
197 |
+
layer_scale_type: str = None, # from cait; possible values are None, "per_channel", "scalar"
|
198 |
+
layer_scale_init_value: float = 1e-4, # from cait; float
|
199 |
+
weight_init_style: str = "jax", # possible values jax or pytorch
|
200 |
+
):
|
201 |
+
"""
|
202 |
+
Simple Transformer with the following features
|
203 |
+
1. Supports masked attention
|
204 |
+
2. Supports DropPath
|
205 |
+
3. Supports LayerScale
|
206 |
+
4. Supports Dropout in Attention and FFN
|
207 |
+
5. Makes few assumptions about the input except that it is a Tensor
|
208 |
+
"""
|
209 |
+
super().__init__()
|
210 |
+
self.pre_transformer_layer = pre_transformer_layer
|
211 |
+
if drop_path_type == "progressive":
|
212 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
|
213 |
+
elif drop_path_type == "uniform":
|
214 |
+
dpr = [drop_path_rate for i in range(num_blocks)]
|
215 |
+
else:
|
216 |
+
raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
|
217 |
+
|
218 |
+
self.blocks = nn.Sequential(
|
219 |
+
*[
|
220 |
+
block(
|
221 |
+
dim=embed_dim,
|
222 |
+
attn_target=attn_target,
|
223 |
+
mlp_ratio=mlp_ratio,
|
224 |
+
ffn_dropout_rate=ffn_dropout_rate,
|
225 |
+
drop_path=dpr[i],
|
226 |
+
norm_layer=norm_layer,
|
227 |
+
layer_scale_type=layer_scale_type,
|
228 |
+
layer_scale_init_value=layer_scale_init_value,
|
229 |
+
)
|
230 |
+
for i in range(num_blocks)
|
231 |
+
]
|
232 |
+
)
|
233 |
+
self.post_transformer_layer = post_transformer_layer
|
234 |
+
self.weight_init_style = weight_init_style
|
235 |
+
self.apply(self._init_weights)
|
236 |
+
|
237 |
+
def _init_weights(self, m):
|
238 |
+
if isinstance(m, nn.Linear):
|
239 |
+
if self.weight_init_style == "jax":
|
240 |
+
# Based on MAE and official Jax ViT implementation
|
241 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
242 |
+
elif self.weight_init_style == "pytorch":
|
243 |
+
# PyTorch ViT uses trunc_normal_
|
244 |
+
trunc_normal_(m.weight, std=0.02)
|
245 |
+
|
246 |
+
if m.bias is not None:
|
247 |
+
nn.init.constant_(m.bias, 0)
|
248 |
+
elif isinstance(m, (nn.LayerNorm)):
|
249 |
+
nn.init.constant_(m.bias, 0)
|
250 |
+
nn.init.constant_(m.weight, 1.0)
|
251 |
+
|
252 |
+
def forward(
|
253 |
+
self,
|
254 |
+
tokens: torch.Tensor,
|
255 |
+
attn_mask: torch.Tensor = None,
|
256 |
+
use_checkpoint: bool = False,
|
257 |
+
checkpoint_every_n: int = 1,
|
258 |
+
checkpoint_blk_ids: List[int] = None,
|
259 |
+
):
|
260 |
+
"""
|
261 |
+
Inputs
|
262 |
+
- tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
|
263 |
+
- attn: mask of shape L x L
|
264 |
+
|
265 |
+
Output
|
266 |
+
- x: data of shape N x L x D (or L x N x D depending on the attention implementation)
|
267 |
+
"""
|
268 |
+
if self.pre_transformer_layer:
|
269 |
+
tokens = self.pre_transformer_layer(tokens)
|
270 |
+
if use_checkpoint and checkpoint_blk_ids is None:
|
271 |
+
checkpoint_blk_ids = [
|
272 |
+
blk_id
|
273 |
+
for blk_id in range(len(self.blocks))
|
274 |
+
if blk_id % checkpoint_every_n == 0
|
275 |
+
]
|
276 |
+
if checkpoint_blk_ids:
|
277 |
+
checkpoint_blk_ids = set(checkpoint_blk_ids)
|
278 |
+
for blk_id, blk in enumerate(self.blocks):
|
279 |
+
if use_checkpoint and blk_id in checkpoint_blk_ids:
|
280 |
+
tokens = checkpoint.checkpoint(
|
281 |
+
blk, tokens, attn_mask, use_reentrant=False
|
282 |
+
)
|
283 |
+
else:
|
284 |
+
tokens = blk(tokens, attn_mask=attn_mask)
|
285 |
+
if self.post_transformer_layer:
|
286 |
+
tokens = self.post_transformer_layer(tokens)
|
287 |
+
return tokens
|
288 |
+
|
289 |
+
|
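The "progressive" drop_path_type above linearly ramps the stochastic-depth rate from 0 to drop_path_rate across blocks; a small illustrative sketch (not part of the uploaded file, the rates are assumed):

import torch

drop_path_rate, num_blocks = 0.1, 12
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
print([round(r, 3) for r in dpr])   # 0.0, 0.009, ..., 0.1 — one rate per block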
290 |
+
def get_sinusoid_encoding_table(n_position, d_hid):
|
291 |
+
"""Sinusoid position encoding table"""
|
292 |
+
|
293 |
+
# TODO: make it with torch instead of numpy
|
294 |
+
def get_position_angle_vec(position):
|
295 |
+
return [
|
296 |
+
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
297 |
+
for hid_j in range(d_hid)
|
298 |
+
]
|
299 |
+
|
300 |
+
sinusoid_table = np.array(
|
301 |
+
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
|
302 |
+
)
|
303 |
+
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
304 |
+
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
305 |
+
|
306 |
+
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
307 |
+
|
308 |
+
|
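A tiny check of the sinusoid table above (illustrative sketch, not part of the uploaded file): even feature columns hold sines, odd columns cosines, and the result has shape (1, n_position, d_hid).

import numpy as np
import torch

n_position, d_hid = 4, 6
table = np.array([[p / np.power(10000, 2 * (j // 2) / d_hid) for j in range(d_hid)]
                  for p in range(n_position)])
table[:, 0::2] = np.sin(table[:, 0::2])   # dim 2i
table[:, 1::2] = np.cos(table[:, 1::2])   # dim 2i+1
print(torch.FloatTensor(table).unsqueeze(0).shape)   # torch.Size([1, 4, 6])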
309 |
+
def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
|
310 |
+
N = pos_embed.shape[1]
|
311 |
+
if N == target_spatial_size:
|
312 |
+
return pos_embed
|
313 |
+
dim = pos_embed.shape[-1]
|
314 |
+
# nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
|
315 |
+
pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
|
316 |
+
pos_embed = nn.functional.interpolate(
|
317 |
+
pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
|
318 |
+
0, 3, 1, 2
|
319 |
+
),
|
320 |
+
scale_factor=math.sqrt(target_spatial_size / N),
|
321 |
+
mode="bicubic",
|
322 |
+
)
|
323 |
+
if updated:
|
324 |
+
pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
|
325 |
+
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
326 |
+
return pos_embed
|
327 |
+
|
328 |
+
|
329 |
+
def interpolate_pos_encoding(
|
330 |
+
npatch_per_img,
|
331 |
+
pos_embed,
|
332 |
+
patches_layout,
|
333 |
+
input_shape=None,
|
334 |
+
first_patch_idx=1,
|
335 |
+
):
|
336 |
+
assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
|
337 |
+
N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
|
338 |
+
if npatch_per_img == N:
|
339 |
+
return pos_embed
|
340 |
+
|
341 |
+
# assert (
|
342 |
+
# patches_layout[-1] == patches_layout[-2]
|
343 |
+
# ), "Interpolation of pos embed not supported for non-square layouts"
|
344 |
+
|
345 |
+
class_emb = pos_embed[:, :first_patch_idx]
|
346 |
+
pos_embed = pos_embed[:, first_patch_idx:]
|
347 |
+
|
348 |
+
if input_shape is None or patches_layout[0] == 1:
|
349 |
+
# simple 2D pos embedding, no temporal component
|
350 |
+
pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
|
351 |
+
elif patches_layout[0] > 1:
|
352 |
+
# pos embed has a temporal component
|
353 |
+
assert len(input_shape) == 4, "temporal interpolation not supported"
|
354 |
+
# we only support 2D interpolation in this case
|
355 |
+
num_frames = patches_layout[0]
|
356 |
+
num_spatial_tokens = patches_layout[1] * patches_layout[2]
|
357 |
+
pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
|
358 |
+
# interpolate embedding for zeroth frame
|
359 |
+
pos_embed = interpolate_pos_encoding_2d(
|
360 |
+
npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
|
361 |
+
)
|
362 |
+
else:
|
363 |
+
raise ValueError("This type of interpolation isn't implemented")
|
364 |
+
|
365 |
+
return torch.cat((class_emb, pos_embed), dim=1)
|
366 |
+
|
367 |
+
|
368 |
+
def _get_pos_embedding(
|
369 |
+
npatch_per_img,
|
370 |
+
pos_embed,
|
371 |
+
patches_layout,
|
372 |
+
input_shape,
|
373 |
+
first_patch_idx=1,
|
374 |
+
):
|
375 |
+
pos_embed = interpolate_pos_encoding(
|
376 |
+
npatch_per_img,
|
377 |
+
pos_embed,
|
378 |
+
patches_layout,
|
379 |
+
input_shape=input_shape,
|
380 |
+
first_patch_idx=first_patch_idx,
|
381 |
+
)
|
382 |
+
return pos_embed
|
383 |
+
|
384 |
+
|
385 |
+
class VerboseNNModule(nn.Module):
|
386 |
+
"""
|
387 |
+
Wrapper around nn.Module that prints registered buffers and parameter names.
|
388 |
+
"""
|
389 |
+
|
390 |
+
@staticmethod
|
391 |
+
def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
|
392 |
+
st = (
|
393 |
+
"("
|
394 |
+
+ name
|
395 |
+
+ "): "
|
396 |
+
+ "tensor("
|
397 |
+
+ str(tuple(tensor[1].shape))
|
398 |
+
+ ", requires_grad="
|
399 |
+
+ str(tensor[1].requires_grad)
|
400 |
+
+ ")\n"
|
401 |
+
)
|
402 |
+
return st
|
403 |
+
|
404 |
+
def extra_repr(self) -> str:
|
405 |
+
named_modules = set()
|
406 |
+
for p in self.named_modules():
|
407 |
+
named_modules.update([p[0]])
|
408 |
+
named_modules = list(named_modules)
|
409 |
+
|
410 |
+
string_repr = ""
|
411 |
+
for p in self.named_parameters():
|
412 |
+
name = p[0].split(".")[0]
|
413 |
+
if name not in named_modules:
|
414 |
+
string_repr += self.get_readable_tensor_repr(name, p)
|
415 |
+
|
416 |
+
for p in self.named_buffers():
|
417 |
+
name = p[0].split(".")[0]
|
418 |
+
string_repr += self.get_readable_tensor_repr(name, p)
|
419 |
+
|
420 |
+
return string_repr
|
421 |
+
|
422 |
+
|
423 |
+
class PatchEmbedGeneric(nn.Module):
|
424 |
+
"""
|
425 |
+
PatchEmbed from Hydra
|
426 |
+
"""
|
427 |
+
|
428 |
+
def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
|
429 |
+
super().__init__()
|
430 |
+
|
431 |
+
if len(proj_stem) > 1:
|
432 |
+
self.proj = nn.Sequential(*proj_stem)
|
433 |
+
else:
|
434 |
+
# Special case to be able to load pre-trained models that were
|
435 |
+
# trained with a standard stem
|
436 |
+
self.proj = proj_stem[0]
|
437 |
+
self.norm_layer = norm_layer
|
438 |
+
|
439 |
+
def get_patch_layout(self, img_size):
|
440 |
+
with torch.no_grad():
|
441 |
+
dummy_img = torch.zeros(
|
442 |
+
[
|
443 |
+
1,
|
444 |
+
]
|
445 |
+
+ img_size
|
446 |
+
)
|
447 |
+
dummy_out = self.proj(dummy_img)
|
448 |
+
embed_dim = dummy_out.shape[1]
|
449 |
+
patches_layout = tuple(dummy_out.shape[2:])
|
450 |
+
num_patches = np.prod(patches_layout)
|
451 |
+
return patches_layout, num_patches, embed_dim
|
452 |
+
|
453 |
+
def forward(self, x):
|
454 |
+
x = self.proj(x)
|
455 |
+
# B C (T) H W -> B (T)HW C
|
456 |
+
x = x.flatten(2).transpose(1, 2)
|
457 |
+
if self.norm_layer is not None:
|
458 |
+
x = self.norm_layer(x)
|
459 |
+
return x
|
460 |
+
|
461 |
+
|
462 |
+
class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
|
463 |
+
def __init__(
|
464 |
+
self,
|
465 |
+
patches_layout: List,
|
466 |
+
num_patches: int,
|
467 |
+
num_cls_tokens: int,
|
468 |
+
embed_dim: int,
|
469 |
+
learnable: bool,
|
470 |
+
) -> None:
|
471 |
+
super().__init__()
|
472 |
+
self.num_cls_tokens = num_cls_tokens
|
473 |
+
self.patches_layout = patches_layout
|
474 |
+
self.num_patches = num_patches
|
475 |
+
self.num_tokens = num_cls_tokens + num_patches
|
476 |
+
self.learnable = learnable
|
477 |
+
if self.learnable:
|
478 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
|
479 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
480 |
+
else:
|
481 |
+
self.register_buffer(
|
482 |
+
"pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
|
483 |
+
)
|
484 |
+
|
485 |
+
def get_pos_embedding(self, vision_input, all_vision_tokens):
|
486 |
+
input_shape = vision_input.shape
|
487 |
+
pos_embed = _get_pos_embedding(
|
488 |
+
all_vision_tokens.size(1) - self.num_cls_tokens,
|
489 |
+
pos_embed=self.pos_embed,
|
490 |
+
patches_layout=self.patches_layout,
|
491 |
+
input_shape=input_shape,
|
492 |
+
first_patch_idx=self.num_cls_tokens,
|
493 |
+
)
|
494 |
+
return pos_embed
|
495 |
+
|
496 |
+
|
497 |
+
class RGBDTPreprocessor(VerboseNNModule):
|
498 |
+
def __init__(
|
499 |
+
self,
|
500 |
+
rgbt_stem: PatchEmbedGeneric,
|
501 |
+
depth_stem: PatchEmbedGeneric,
|
502 |
+
img_size: List = (3, 224, 224),
|
503 |
+
num_cls_tokens: int = 1,
|
504 |
+
pos_embed_fn: Callable = None,
|
505 |
+
use_type_embed: bool = False,
|
506 |
+
init_param_style: str = "openclip",
|
507 |
+
) -> None:
|
508 |
+
super().__init__()
|
509 |
+
stem = rgbt_stem if rgbt_stem is not None else depth_stem
|
510 |
+
(
|
511 |
+
self.patches_layout,
|
512 |
+
self.num_patches,
|
513 |
+
self.embed_dim,
|
514 |
+
) = stem.get_patch_layout(img_size)
|
515 |
+
self.rgbt_stem = rgbt_stem
|
516 |
+
self.depth_stem = depth_stem
|
517 |
+
self.use_pos_embed = pos_embed_fn is not None
|
518 |
+
self.use_type_embed = use_type_embed
|
519 |
+
self.num_cls_tokens = num_cls_tokens
|
520 |
+
|
521 |
+
if self.use_pos_embed:
|
522 |
+
self.pos_embedding_helper = pos_embed_fn(
|
523 |
+
patches_layout=self.patches_layout,
|
524 |
+
num_cls_tokens=num_cls_tokens,
|
525 |
+
num_patches=self.num_patches,
|
526 |
+
embed_dim=self.embed_dim,
|
527 |
+
)
|
528 |
+
if self.num_cls_tokens > 0:
|
529 |
+
self.cls_token = nn.Parameter(
|
530 |
+
torch.zeros(1, self.num_cls_tokens, self.embed_dim)
|
531 |
+
)
|
532 |
+
if self.use_type_embed:
|
533 |
+
self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
534 |
+
|
535 |
+
self.init_parameters(init_param_style)
|
536 |
+
|
537 |
+
@torch.no_grad()
|
538 |
+
def init_parameters(self, init_param_style):
|
539 |
+
if init_param_style == "openclip":
|
540 |
+
# OpenCLIP style initialization
|
541 |
+
scale = self.embed_dim ** -0.5
|
542 |
+
if self.use_pos_embed:
|
543 |
+
nn.init.normal_(self.pos_embedding_helper.pos_embed)
|
544 |
+
self.pos_embedding_helper.pos_embed *= scale
|
545 |
+
|
546 |
+
if self.num_cls_tokens > 0:
|
547 |
+
nn.init.normal_(self.cls_token)
|
548 |
+
self.cls_token *= scale
|
549 |
+
elif init_param_style == "vit":
|
550 |
+
self.cls_token.data.fill_(0)
|
551 |
+
else:
|
552 |
+
raise ValueError(f"Unknown init {init_param_style}")
|
553 |
+
|
554 |
+
if self.use_type_embed:
|
555 |
+
nn.init.normal_(self.type_embed)
|
556 |
+
|
557 |
+
def get_pos_emb_2(self, input, stem):
|
558 |
+
patches = stem.proj(input)
|
559 |
+
target_size = patches.shape[-2:]
|
560 |
+
original_size = list(self.pos_embedding_helper.patches_layout)[-2:]
|
561 |
+
|
562 |
+
orig_ce = self.pos_embedding_helper.pos_embed[:, 0, :]
|
563 |
+
orig_pe = ((self.pos_embedding_helper.pos_embed[:, 1:, :]
|
564 |
+
.reshape(1, *original_size, self.embed_dim))
|
565 |
+
.permute(0, 3, 1, 2))
|
566 |
+
|
567 |
+
new_pe = F.interpolate(orig_pe, size=target_size, mode="bicubic")
|
568 |
+
|
569 |
+
new_full_pe = torch.cat([orig_ce.unsqueeze(1), new_pe.permute(0, 2, 3, 1).reshape(1, -1, self.embed_dim)],
|
570 |
+
dim=1)
|
571 |
+
|
572 |
+
return new_full_pe
|
573 |
+
|
574 |
+
def tokenize_input_and_cls_pos(self, input, stem, mask):
|
575 |
+
# tokens is of shape B x L x D
|
576 |
+
tokens = stem(input)
|
577 |
+
assert tokens.ndim == 3
|
578 |
+
assert tokens.shape[2] == self.embed_dim
|
579 |
+
B = tokens.shape[0]
|
580 |
+
if self.num_cls_tokens > 0:
|
581 |
+
class_tokens = self.cls_token.expand(
|
582 |
+
B, -1, -1
|
583 |
+
) # stole class_tokens impl from Phil Wang, thanks
|
584 |
+
tokens = torch.cat((class_tokens, tokens), dim=1)
|
585 |
+
if self.use_pos_embed:
|
586 |
+
pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
|
587 |
+
# pos_embed = self.get_pos_emb_2(input, stem)
|
588 |
+
tokens = tokens + pos_embed
|
589 |
+
if self.use_type_embed:
|
590 |
+
tokens = tokens + self.type_embed.expand(B, -1, -1)
|
591 |
+
return tokens
|
592 |
+
|
593 |
+
def forward(self, vision=None, depth=None, patch_mask=None):
|
594 |
+
if patch_mask is not None:
|
595 |
+
raise NotImplementedError()
|
596 |
+
|
597 |
+
if vision is not None:
|
598 |
+
vision_tokens = self.tokenize_input_and_cls_pos(
|
599 |
+
vision, self.rgbt_stem, patch_mask
|
600 |
+
)
|
601 |
+
|
602 |
+
if depth is not None:
|
603 |
+
depth_tokens = self.tokenize_input_and_cls_pos(
|
604 |
+
depth, self.depth_stem, patch_mask
|
605 |
+
)
|
606 |
+
|
607 |
+
# aggregate tokens
|
608 |
+
if vision is not None and depth is not None:
|
609 |
+
final_tokens = vision_tokens + depth_tokens
|
610 |
+
else:
|
611 |
+
final_tokens = vision_tokens if vision is not None else depth_tokens
|
612 |
+
return_dict = {
|
613 |
+
"trunk": {
|
614 |
+
"tokens": final_tokens,
|
615 |
+
},
|
616 |
+
"head": {},
|
617 |
+
}
|
618 |
+
return return_dict
|
619 |
+
|
620 |
+
|
621 |
+
class AudioPreprocessor(RGBDTPreprocessor):
|
622 |
+
def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
|
623 |
+
super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
|
624 |
+
|
625 |
+
def forward(self, audio=None):
|
626 |
+
return super().forward(vision=audio)
|
627 |
+
|
628 |
+
|
629 |
+
class ThermalPreprocessor(RGBDTPreprocessor):
|
630 |
+
def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
|
631 |
+
super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
|
632 |
+
|
633 |
+
def forward(self, thermal=None):
|
634 |
+
return super().forward(vision=thermal)
|
635 |
+
|
636 |
+
|
637 |
+
def build_causal_attention_mask(context_length):
|
638 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
639 |
+
# pytorch uses additive attention mask; fill with -inf
|
640 |
+
mask = torch.empty(context_length, context_length, requires_grad=False)
|
641 |
+
mask.fill_(float("-inf"))
|
642 |
+
mask.triu_(1) # zero out the lower diagonal
|
643 |
+
return mask
|
644 |
+
|
645 |
+
|
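What build_causal_attention_mask produces, shown on a toy context of 4 tokens (illustrative sketch, not part of the uploaded file):

import torch

mask = torch.empty(4, 4).fill_(float("-inf")).triu_(1)
print(mask)
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])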
646 |
+
class TextPreprocessor(VerboseNNModule):
|
647 |
+
def __init__(
|
648 |
+
self,
|
649 |
+
vocab_size: int,
|
650 |
+
context_length: int,
|
651 |
+
embed_dim: int,
|
652 |
+
causal_masking: bool,
|
653 |
+
supply_seq_len_to_head: bool = True,
|
654 |
+
num_cls_tokens: int = 0,
|
655 |
+
init_param_style: str = "openclip",
|
656 |
+
) -> None:
|
657 |
+
super().__init__()
|
658 |
+
self.vocab_size = vocab_size
|
659 |
+
self.context_length = context_length
|
660 |
+
self.token_embedding = nn.Embedding(vocab_size, embed_dim)
|
661 |
+
self.pos_embed = nn.Parameter(
|
662 |
+
torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
|
663 |
+
)
|
664 |
+
self.causal_masking = causal_masking
|
665 |
+
if self.causal_masking:
|
666 |
+
mask = build_causal_attention_mask(self.context_length)
|
667 |
+
# register the mask as a buffer, so it can be moved to the right device
|
668 |
+
self.register_buffer("mask", mask)
|
669 |
+
|
670 |
+
self.supply_seq_len_to_head = supply_seq_len_to_head
|
671 |
+
self.num_cls_tokens = num_cls_tokens
|
672 |
+
self.embed_dim = embed_dim
|
673 |
+
if num_cls_tokens > 0:
|
674 |
+
assert self.causal_masking is False, "Masking + CLS token isn't implemented"
|
675 |
+
self.cls_token = nn.Parameter(
|
676 |
+
torch.zeros(1, self.num_cls_tokens, embed_dim)
|
677 |
+
)
|
678 |
+
|
679 |
+
self.init_parameters(init_param_style)
|
680 |
+
|
681 |
+
@torch.no_grad()
|
682 |
+
def init_parameters(self, init_param_style="openclip"):
|
683 |
+
# OpenCLIP style initialization
|
684 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
685 |
+
nn.init.normal_(self.pos_embed, std=0.01)
|
686 |
+
|
687 |
+
if init_param_style == "openclip":
|
688 |
+
# OpenCLIP style initialization
|
689 |
+
scale = self.embed_dim ** -0.5
|
690 |
+
if self.num_cls_tokens > 0:
|
691 |
+
nn.init.normal_(self.cls_token)
|
692 |
+
self.cls_token *= scale
|
693 |
+
elif init_param_style == "vit":
|
694 |
+
self.cls_token.data.fill_(0)
|
695 |
+
else:
|
696 |
+
raise ValueError(f"Unknown init {init_param_style}")
|
697 |
+
|
698 |
+
def forward(self, text):
|
699 |
+
# text tokens are of shape B x L x D
|
700 |
+
text_tokens = self.token_embedding(text)
|
701 |
+
# concat CLS tokens if any
|
702 |
+
if self.num_cls_tokens > 0:
|
703 |
+
B = text_tokens.shape[0]
|
704 |
+
class_tokens = self.cls_token.expand(
|
705 |
+
B, -1, -1
|
706 |
+
) # stole class_tokens impl from Phil Wang, thanks
|
707 |
+
text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
|
708 |
+
text_tokens = text_tokens + self.pos_embed
|
709 |
+
return_dict = {
|
710 |
+
"trunk": {
|
711 |
+
"tokens": text_tokens,
|
712 |
+
},
|
713 |
+
"head": {},
|
714 |
+
}
|
715 |
+
# Compute sequence length after adding CLS tokens
|
716 |
+
if self.supply_seq_len_to_head:
|
717 |
+
text_lengths = text.argmax(dim=-1)
|
718 |
+
return_dict["head"] = {
|
719 |
+
"seq_len": text_lengths,
|
720 |
+
}
|
721 |
+
if self.causal_masking:
|
722 |
+
return_dict["trunk"].update({"attn_mask": self.mask})
|
723 |
+
return return_dict
|
724 |
+
|
725 |
+
|
726 |
+
class Im2Video(nn.Module):
|
727 |
+
"""Convert an image into a trivial video."""
|
728 |
+
|
729 |
+
def __init__(self, time_dim=2):
|
730 |
+
super().__init__()
|
731 |
+
self.time_dim = time_dim
|
732 |
+
|
733 |
+
def forward(self, x):
|
734 |
+
if x.ndim == 4:
|
735 |
+
# B, C, H, W -> B, C, T, H, W
|
736 |
+
return x.unsqueeze(self.time_dim)
|
737 |
+
elif x.ndim == 5:
|
738 |
+
return x
|
739 |
+
else:
|
740 |
+
raise ValueError(f"Dimension incorrect {x.shape}")
|
741 |
+
|
742 |
+
|
743 |
+
class PadIm2Video(Im2Video):
|
744 |
+
def __init__(self, ntimes, pad_type, time_dim=2):
|
745 |
+
super().__init__(time_dim=time_dim)
|
746 |
+
assert ntimes > 0
|
747 |
+
assert pad_type in ["zero", "repeat"]
|
748 |
+
self.ntimes = ntimes
|
749 |
+
self.pad_type = pad_type
|
750 |
+
|
751 |
+
def forward(self, x):
|
752 |
+
x = super().forward(x)
|
753 |
+
if x.shape[self.time_dim] == 1:
|
754 |
+
if self.pad_type == "repeat":
|
755 |
+
new_shape = [1] * len(x.shape)
|
756 |
+
new_shape[self.time_dim] = self.ntimes
|
757 |
+
x = x.repeat(new_shape)
|
758 |
+
elif self.pad_type == "zero":
|
759 |
+
padarg = [0, 0] * len(x.shape)
|
760 |
+
padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
|
761 |
+
x = nn.functional.pad(x, padarg)
|
762 |
+
return x
|
763 |
+
|
764 |
+
|
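How PadIm2Video behaves on a single image (illustrative sketch, not part of the uploaded file; it relies on the PadIm2Video class defined above):

import torch

x = torch.randn(1, 3, 224, 224)                 # B x C x H x W image batch
pad = PadIm2Video(ntimes=2, pad_type="repeat")
print(pad(x).shape)                             # torch.Size([1, 3, 2, 224, 224])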
765 |
+
# Modified from github.com/openai/CLIP
|
766 |
+
@lru_cache()
|
767 |
+
def bytes_to_unicode():
|
768 |
+
"""
|
769 |
+
Returns a list of utf-8 bytes and a corresponding list of unicode strings.
|
770 |
+
The reversible bpe codes work on unicode strings.
|
771 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
772 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
773 |
+
This is a significant percentage of your normal, say, 32K bpe vocab.
|
774 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
775 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
776 |
+
"""
|
777 |
+
bs = (
|
778 |
+
list(range(ord("!"), ord("~") + 1))
|
779 |
+
+ list(range(ord("¡"), ord("¬") + 1))
|
780 |
+
+ list(range(ord("®"), ord("ÿ") + 1))
|
781 |
+
)
|
782 |
+
cs = bs[:]
|
783 |
+
n = 0
|
784 |
+
for b in range(2 ** 8):
|
785 |
+
if b not in bs:
|
786 |
+
bs.append(b)
|
787 |
+
cs.append(2 ** 8 + n)
|
788 |
+
n += 1
|
789 |
+
cs = [chr(n) for n in cs]
|
790 |
+
return dict(zip(bs, cs))
|
791 |
+
|
792 |
+
|
793 |
+
def get_pairs(word):
|
794 |
+
"""Return set of symbol pairs in a word.
|
795 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
796 |
+
"""
|
797 |
+
pairs = set()
|
798 |
+
prev_char = word[0]
|
799 |
+
for char in word[1:]:
|
800 |
+
pairs.add((prev_char, char))
|
801 |
+
prev_char = char
|
802 |
+
return pairs
|
803 |
+
|
804 |
+
|
805 |
+
def basic_clean(text):
|
806 |
+
text = ftfy.fix_text(text)
|
807 |
+
text = html.unescape(html.unescape(text))
|
808 |
+
return text.strip()
|
809 |
+
|
810 |
+
|
811 |
+
def whitespace_clean(text):
|
812 |
+
text = re.sub(r"\s+", " ", text)
|
813 |
+
text = text.strip()
|
814 |
+
return text
|
815 |
+
|
816 |
+
|
817 |
+
class SimpleTokenizer(object):
|
818 |
+
def __init__(self, bpe_path: str, context_length=77):
|
819 |
+
self.byte_encoder = bytes_to_unicode()
|
820 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
821 |
+
|
822 |
+
with g_pathmgr.open(bpe_path, "rb") as fh:
|
823 |
+
bpe_bytes = io.BytesIO(fh.read())
|
824 |
+
merges = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
|
825 |
+
merges = merges[1: 49152 - 256 - 2 + 1]
|
826 |
+
merges = [tuple(merge.split()) for merge in merges]
|
827 |
+
vocab = list(bytes_to_unicode().values())
|
828 |
+
vocab = vocab + [v + "</w>" for v in vocab]
|
829 |
+
for merge in merges:
|
830 |
+
vocab.append("".join(merge))
|
831 |
+
vocab.extend(["<|startoftext|>", "<|endoftext|>"])
|
832 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
833 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
834 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
835 |
+
self.cache = {
|
836 |
+
"<|startoftext|>": "<|startoftext|>",
|
837 |
+
"<|endoftext|>": "<|endoftext|>",
|
838 |
+
}
|
839 |
+
self.pat = re.compile(
|
840 |
+
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
841 |
+
re.IGNORECASE,
|
842 |
+
)
|
843 |
+
self.context_length = context_length
|
844 |
+
|
845 |
+
def bpe(self, token):
|
846 |
+
if token in self.cache:
|
847 |
+
return self.cache[token]
|
848 |
+
word = tuple(token[:-1]) + (token[-1] + "</w>",)
|
849 |
+
pairs = get_pairs(word)
|
850 |
+
|
851 |
+
if not pairs:
|
852 |
+
return token + "</w>"
|
853 |
+
|
854 |
+
while True:
|
855 |
+
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
856 |
+
if bigram not in self.bpe_ranks:
|
857 |
+
break
|
858 |
+
first, second = bigram
|
859 |
+
new_word = []
|
860 |
+
i = 0
|
861 |
+
while i < len(word):
|
862 |
+
try:
|
863 |
+
j = word.index(first, i)
|
864 |
+
new_word.extend(word[i:j])
|
865 |
+
i = j
|
866 |
+
except:
|
867 |
+
new_word.extend(word[i:])
|
868 |
+
break
|
869 |
+
|
870 |
+
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
871 |
+
new_word.append(first + second)
|
872 |
+
i += 2
|
873 |
+
else:
|
874 |
+
new_word.append(word[i])
|
875 |
+
i += 1
|
876 |
+
new_word = tuple(new_word)
|
877 |
+
word = new_word
|
878 |
+
if len(word) == 1:
|
879 |
+
break
|
880 |
+
else:
|
881 |
+
pairs = get_pairs(word)
|
882 |
+
word = " ".join(word)
|
883 |
+
self.cache[token] = word
|
884 |
+
return word
|
885 |
+
|
886 |
+
def encode(self, text):
|
887 |
+
bpe_tokens = []
|
888 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
889 |
+
for token in re.findall(self.pat, text):
|
890 |
+
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
891 |
+
bpe_tokens.extend(
|
892 |
+
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
|
893 |
+
)
|
894 |
+
return bpe_tokens
|
895 |
+
|
896 |
+
def decode(self, tokens):
|
897 |
+
text = "".join([self.decoder[token] for token in tokens])
|
898 |
+
text = (
|
899 |
+
bytearray([self.byte_decoder[c] for c in text])
|
900 |
+
.decode("utf-8", errors="replace")
|
901 |
+
.replace("</w>", " ")
|
902 |
+
)
|
903 |
+
return text
|
904 |
+
|
905 |
+
def __call__(self, texts, context_length=None):
|
906 |
+
if not context_length:
|
907 |
+
context_length = self.context_length
|
908 |
+
|
909 |
+
if isinstance(texts, str):
|
910 |
+
texts = [texts]
|
911 |
+
|
912 |
+
sot_token = self.encoder["<|startoftext|>"]
|
913 |
+
eot_token = self.encoder["<|endoftext|>"]
|
914 |
+
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
915 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
916 |
+
|
917 |
+
for i, tokens in enumerate(all_tokens):
|
918 |
+
tokens = tokens[:context_length]
|
919 |
+
result[i, : len(tokens)] = torch.tensor(tokens)
|
920 |
+
|
921 |
+
if len(result) == 1:
|
922 |
+
return result[0]
|
923 |
+
return result
|
924 |
+
|
925 |
+
|
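Typical use of SimpleTokenizer (illustrative sketch, not part of the uploaded file; the bpe_path is an assumption — ImageBind distributes a gzipped BPE vocabulary, and the exact filename/location may differ):

tokenizer = SimpleTokenizer(bpe_path="bpe_simple_vocab_16e6.txt.gz")  # hypothetical path
tokens = tokenizer(["a dog barking", "a cat meowing"])
print(tokens.shape)   # torch.Size([2, 77]) — padded to context_length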
926 |
+
class Normalize(nn.Module):
|
927 |
+
def __init__(self, dim: int) -> None:
|
928 |
+
super().__init__()
|
929 |
+
self.dim = dim
|
930 |
+
|
931 |
+
def forward(self, x):
|
932 |
+
return torch.nn.functional.normalize(x, dim=self.dim, p=2)
|
933 |
+
|
934 |
+
|
935 |
+
class LearnableLogitScaling(nn.Module):
|
936 |
+
def __init__(
|
937 |
+
self,
|
938 |
+
logit_scale_init: float = 1 / 0.07,
|
939 |
+
learnable: bool = True,
|
940 |
+
max_logit_scale: float = 100,
|
941 |
+
) -> None:
|
942 |
+
super().__init__()
|
943 |
+
self.max_logit_scale = max_logit_scale
|
944 |
+
self.logit_scale_init = logit_scale_init
|
945 |
+
self.learnable = learnable
|
946 |
+
log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
|
947 |
+
if learnable:
|
948 |
+
self.log_logit_scale = nn.Parameter(log_logit_scale)
|
949 |
+
else:
|
950 |
+
self.register_buffer("log_logit_scale", log_logit_scale)
|
951 |
+
|
952 |
+
def forward(self, x):
|
953 |
+
return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
|
954 |
+
|
955 |
+
def extra_repr(self):
|
956 |
+
st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
|
957 |
+
return st
|
958 |
+
|
959 |
+
|
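The effect of LearnableLogitScaling above: similarities are multiplied by exp(log_logit_scale), clipped at max_logit_scale (illustrative sketch, not part of the uploaded file):

import numpy as np
import torch

log_scale = torch.ones([]) * np.log(1 / 0.07)        # the default logit_scale_init
sims = torch.tensor([0.1, 0.5, 0.9])                 # toy cosine similarities
print(torch.clip(log_scale.exp(), max=100) * sims)   # ≈ tensor([ 1.43,  7.14, 12.86])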
960 |
+
class EinOpsRearrange(nn.Module):
|
961 |
+
def __init__(self, rearrange_expr: str, **kwargs) -> None:
|
962 |
+
super().__init__()
|
963 |
+
self.rearrange_expr = rearrange_expr
|
964 |
+
self.kwargs = kwargs
|
965 |
+
|
966 |
+
def forward(self, x):
|
967 |
+
assert isinstance(x, torch.Tensor)
|
968 |
+
return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
|
969 |
+
|
970 |
+
|
971 |
+
class IMUPreprocessor(VerboseNNModule):
|
972 |
+
def __init__(
|
973 |
+
self,
|
974 |
+
kernel_size: int,
|
975 |
+
imu_stem: PatchEmbedGeneric,
|
976 |
+
embed_dim: int,
|
977 |
+
img_size: List = (6, 2000),
|
978 |
+
num_cls_tokens: int = 1,
|
979 |
+
pos_embed_fn: Callable = None,
|
980 |
+
init_param_style: str = "openclip",
|
981 |
+
) -> None:
|
982 |
+
super().__init__()
|
983 |
+
stem = imu_stem
|
984 |
+
self.imu_stem = imu_stem
|
985 |
+
self.embed_dim = embed_dim
|
986 |
+
self.use_pos_embed = pos_embed_fn is not None
|
987 |
+
self.num_cls_tokens = num_cls_tokens
|
988 |
+
self.kernel_size = kernel_size
|
989 |
+
self.pos_embed = nn.Parameter(
|
990 |
+
torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
|
991 |
+
)
|
992 |
+
|
993 |
+
if self.num_cls_tokens > 0:
|
994 |
+
self.cls_token = nn.Parameter(
|
995 |
+
torch.zeros(1, self.num_cls_tokens, self.embed_dim)
|
996 |
+
)
|
997 |
+
|
998 |
+
self.init_parameters(init_param_style)
|
999 |
+
|
1000 |
+
@torch.no_grad()
|
1001 |
+
def init_parameters(self, init_param_style):
|
1002 |
+
nn.init.normal_(self.pos_embed, std=0.01)
|
1003 |
+
|
1004 |
+
if init_param_style == "openclip":
|
1005 |
+
# OpenCLIP style initialization
|
1006 |
+
scale = self.embed_dim ** -0.5
|
1007 |
+
|
1008 |
+
if self.num_cls_tokens > 0:
|
1009 |
+
nn.init.normal_(self.cls_token)
|
1010 |
+
self.cls_token *= scale
|
1011 |
+
elif init_param_style == "vit":
|
1012 |
+
self.cls_token.data.fill_(0)
|
1013 |
+
else:
|
1014 |
+
raise ValueError(f"Unknown init {init_param_style}")
|
1015 |
+
|
1016 |
+
def tokenize_input_and_cls_pos(self, input, stem):
|
1017 |
+
# tokens is of shape B x L x D
|
1018 |
+
tokens = stem.norm_layer(stem.proj(input))
|
1019 |
+
assert tokens.ndim == 3
|
1020 |
+
assert tokens.shape[2] == self.embed_dim
|
1021 |
+
B = tokens.shape[0]
|
1022 |
+
if self.num_cls_tokens > 0:
|
1023 |
+
class_tokens = self.cls_token.expand(
|
1024 |
+
B, -1, -1
|
1025 |
+
) # stole class_tokens impl from Phil Wang, thanks
|
1026 |
+
tokens = torch.cat((class_tokens, tokens), dim=1)
|
1027 |
+
if self.use_pos_embed:
|
1028 |
+
tokens = tokens + self.pos_embed
|
1029 |
+
return tokens
|
1030 |
+
|
1031 |
+
def forward(self, imu):
|
1032 |
+
# Patchify
|
1033 |
+
imu = imu.unfold(
|
1034 |
+
-1,
|
1035 |
+
self.kernel_size,
|
1036 |
+
self.kernel_size,
|
1037 |
+
).permute(0, 2, 1, 3)
|
1038 |
+
imu = imu.reshape(imu.size(0), imu.size(1), -1)
|
1039 |
+
|
1040 |
+
imu_tokens = self.tokenize_input_and_cls_pos(
|
1041 |
+
imu,
|
1042 |
+
self.imu_stem,
|
1043 |
+
)
|
1044 |
+
|
1045 |
+
return_dict = {
|
1046 |
+
"trunk": {
|
1047 |
+
"tokens": imu_tokens,
|
1048 |
+
},
|
1049 |
+
"head": {},
|
1050 |
+
}
|
1051 |
+
return return_dict
|
1052 |
+
|
1053 |
+
|
1054 |
+
def cast_if_src_dtype(
|
1055 |
+
tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
|
1056 |
+
):
|
1057 |
+
updated = False
|
1058 |
+
if tensor.dtype == src_dtype:
|
1059 |
+
tensor = tensor.to(dtype=tgt_dtype)
|
1060 |
+
updated = True
|
1061 |
+
return tensor, updated
|
1062 |
+
|
1063 |
+
|
1064 |
+
class QuickGELU(nn.Module):
|
1065 |
+
# From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
|
1066 |
+
def forward(self, x: torch.Tensor):
|
1067 |
+
return x * torch.sigmoid(1.702 * x)
|
1068 |
+
|
1069 |
+
|
1070 |
+
class SelectElement(nn.Module):
|
1071 |
+
def __init__(self, index) -> None:
|
1072 |
+
super().__init__()
|
1073 |
+
self.index = index
|
1074 |
+
|
1075 |
+
def forward(self, x):
|
1076 |
+
assert x.ndim >= 3
|
1077 |
+
return x[:, self.index, ...]
|
1078 |
+
|
1079 |
+
|
1080 |
+
class ReshapeSpatial(nn.Module):
|
1081 |
+
def __init__(self, shape) -> None:
|
1082 |
+
super().__init__()
|
1083 |
+
self.h, self.w = shape
|
1084 |
+
|
1085 |
+
def forward(self, x):
|
1086 |
+
assert x.ndim >= 3
|
1087 |
+
return x[:, 1:, ...].reshape(x.shape[0], self.h, self.w, -1), x[:, 0, :]
|
1088 |
+
|
1089 |
+
|
1090 |
+
class ReshapeAudio(nn.Module):
|
1091 |
+
def __init__(self, shape) -> None:
|
1092 |
+
super().__init__()
|
1093 |
+
self.h, self.w = shape
|
1094 |
+
|
1095 |
+
def forward(self, x):
|
1096 |
+
assert x.ndim == 3
|
1097 |
+
return x[:, 1:, :].reshape(-1, 5, self.h, self.w, x.shape[-1]), x[:, 0, :]
|
1098 |
+
|
1099 |
+
|
1100 |
+
class ApplyTwice(nn.Module):
|
1101 |
+
def __init__(self, module) -> None:
|
1102 |
+
super().__init__()
|
1103 |
+
self.module = module
|
1104 |
+
|
1105 |
+
def forward(self, pair):
|
1106 |
+
return self.module(pair[0]), self.module(pair[1])
|
1107 |
+
|
1108 |
+
|
1109 |
+
class SelectEOSAndProject(nn.Module):
|
1110 |
+
"""
|
1111 |
+
Text Pooling used in OpenCLIP
|
1112 |
+
"""
|
1113 |
+
|
1114 |
+
def __init__(self, proj: nn.Module) -> None:
|
1115 |
+
super().__init__()
|
1116 |
+
self.proj = proj
|
1117 |
+
|
1118 |
+
def forward(self, x, seq_len):
|
1119 |
+
assert x.ndim == 3
|
1120 |
+
# x is of shape B x L x D
|
1121 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
1122 |
+
x = x[torch.arange(x.shape[0]), seq_len]
|
1123 |
+
x = self.proj(x)
|
1124 |
+
return x
|
1125 |
+
|
1126 |
+
|
1127 |
+
ModalityType = SimpleNamespace(
|
1128 |
+
VISION="vision",
|
1129 |
+
TEXT="text",
|
1130 |
+
AUDIO="audio",
|
1131 |
+
THERMAL="thermal",
|
1132 |
+
DEPTH="depth",
|
1133 |
+
IMU="imu",
|
1134 |
+
)
|
1135 |
+
|
1136 |
+
|
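The model defined below consumes a dict keyed by the ModalityType names above and returns one embedding per modality; a minimal sketch of the expected input layout (not part of the uploaded file; the shapes are assumptions based on the preprocessors defined earlier in this file):

import torch

inputs = {
    ModalityType.VISION: torch.randn(2, 3, 2, 224, 224),       # B x C x T x H x W clip
    ModalityType.TEXT: torch.zeros(2, 77, dtype=torch.long),   # B x context_length tokens
}
# out = ImageBindModel()(inputs)                 # each out[k] has shape (2, out_embed_dim)
# logits = out[ModalityType.VISION] @ out[ModalityType.TEXT].T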
1137 |
+
class ImageBindModel(nn.Module):
|
1138 |
+
def __init__(
|
1139 |
+
self,
|
1140 |
+
video_frames=2,
|
1141 |
+
kernel_size=(2, 14, 14),
|
1142 |
+
audio_kernel_size=16,
|
1143 |
+
audio_stride=10,
|
1144 |
+
out_embed_dim=768,
|
1145 |
+
vision_embed_dim=1024,
|
1146 |
+
vision_num_blocks=24,
|
1147 |
+
vision_num_heads=16,
|
1148 |
+
audio_embed_dim=768,
|
1149 |
+
audio_num_blocks=12,
|
1150 |
+
audio_num_heads=12,
|
1151 |
+
audio_num_mel_bins=128,
|
1152 |
+
audio_target_len=204,
|
1153 |
+
audio_drop_path=0.1,
|
1154 |
+
text_embed_dim=768,
|
1155 |
+
text_num_blocks=12,
|
1156 |
+
text_num_heads=12,
|
1157 |
+
depth_embed_dim=384,
|
1158 |
+
depth_kernel_size=16,
|
1159 |
+
depth_num_blocks=12,
|
1160 |
+
depth_num_heads=8,
|
1161 |
+
depth_drop_path=0.0,
|
1162 |
+
thermal_embed_dim=768,
|
1163 |
+
thermal_kernel_size=16,
|
1164 |
+
thermal_num_blocks=12,
|
1165 |
+
thermal_num_heads=12,
|
1166 |
+
thermal_drop_path=0.0,
|
1167 |
+
imu_embed_dim=512,
|
1168 |
+
imu_kernel_size=8,
|
1169 |
+
imu_num_blocks=6,
|
1170 |
+
imu_num_heads=8,
|
1171 |
+
imu_drop_path=0.7,
|
1172 |
+
):
|
1173 |
+
super().__init__()
|
1174 |
+
|
1175 |
+
self.modality_preprocessors = self._create_modality_preprocessors(
|
1176 |
+
video_frames,
|
1177 |
+
vision_embed_dim,
|
1178 |
+
kernel_size,
|
1179 |
+
text_embed_dim,
|
1180 |
+
audio_embed_dim,
|
1181 |
+
audio_kernel_size,
|
1182 |
+
audio_stride,
|
1183 |
+
audio_num_mel_bins,
|
1184 |
+
audio_target_len,
|
1185 |
+
depth_embed_dim,
|
1186 |
+
depth_kernel_size,
|
1187 |
+
thermal_embed_dim,
|
1188 |
+
thermal_kernel_size,
|
1189 |
+
imu_embed_dim,
|
1190 |
+
)
|
1191 |
+
|
1192 |
+
self.modality_trunks = self._create_modality_trunks(
|
1193 |
+
vision_embed_dim,
|
1194 |
+
vision_num_blocks,
|
1195 |
+
vision_num_heads,
|
1196 |
+
text_embed_dim,
|
1197 |
+
text_num_blocks,
|
1198 |
+
text_num_heads,
|
1199 |
+
audio_embed_dim,
|
1200 |
+
audio_num_blocks,
|
1201 |
+
audio_num_heads,
|
1202 |
+
audio_drop_path,
|
1203 |
+
depth_embed_dim,
|
1204 |
+
depth_num_blocks,
|
1205 |
+
depth_num_heads,
|
1206 |
+
depth_drop_path,
|
1207 |
+
thermal_embed_dim,
|
1208 |
+
thermal_num_blocks,
|
1209 |
+
thermal_num_heads,
|
1210 |
+
thermal_drop_path,
|
1211 |
+
imu_embed_dim,
|
1212 |
+
imu_num_blocks,
|
1213 |
+
imu_num_heads,
|
1214 |
+
imu_drop_path,
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
self.modality_heads = self._create_modality_heads(
|
1218 |
+
out_embed_dim,
|
1219 |
+
vision_embed_dim,
|
1220 |
+
text_embed_dim,
|
1221 |
+
audio_embed_dim,
|
1222 |
+
depth_embed_dim,
|
1223 |
+
thermal_embed_dim,
|
1224 |
+
imu_embed_dim,
|
1225 |
+
)
|
1226 |
+
|
1227 |
+
self.modality_postprocessors = self._create_modality_postprocessors(
|
1228 |
+
out_embed_dim
|
1229 |
+
)
|
1230 |
+
|
1231 |
+
def _create_modality_preprocessors(
|
1232 |
+
self,
|
1233 |
+
video_frames=2,
|
1234 |
+
vision_embed_dim=1024,
|
1235 |
+
kernel_size=(2, 14, 14),
|
1236 |
+
text_embed_dim=768,
|
1237 |
+
audio_embed_dim=768,
|
1238 |
+
audio_kernel_size=16,
|
1239 |
+
audio_stride=10,
|
1240 |
+
audio_num_mel_bins=128,
|
1241 |
+
audio_target_len=204,
|
1242 |
+
depth_embed_dim=768,
|
1243 |
+
depth_kernel_size=16,
|
1244 |
+
thermal_embed_dim=768,
|
1245 |
+
thermal_kernel_size=16,
|
1246 |
+
imu_embed_dim=512,
|
1247 |
+
):
|
1248 |
+
rgbt_stem = PatchEmbedGeneric(
|
1249 |
+
proj_stem=[
|
1250 |
+
PadIm2Video(pad_type="repeat", ntimes=2),
|
1251 |
+
nn.Conv3d(
|
1252 |
+
in_channels=3,
|
1253 |
+
kernel_size=kernel_size,
|
1254 |
+
out_channels=vision_embed_dim,
|
1255 |
+
stride=kernel_size,
|
1256 |
+
bias=False,
|
1257 |
+
),
|
1258 |
+
]
|
1259 |
+
)
|
1260 |
+
rgbt_preprocessor = RGBDTPreprocessor(
|
1261 |
+
img_size=[3, video_frames, 224, 224],
|
1262 |
+
num_cls_tokens=1,
|
1263 |
+
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
|
1264 |
+
rgbt_stem=rgbt_stem,
|
1265 |
+
depth_stem=None,
|
1266 |
+
)
|
1267 |
+
|
1268 |
+
text_preprocessor = TextPreprocessor(
|
1269 |
+
context_length=77,
|
1270 |
+
vocab_size=49408,
|
1271 |
+
embed_dim=text_embed_dim,
|
1272 |
+
causal_masking=True,
|
1273 |
+
)
|
1274 |
+
|
1275 |
+
audio_stem = PatchEmbedGeneric(
|
1276 |
+
proj_stem=[
|
1277 |
+
nn.Conv2d(
|
1278 |
+
in_channels=1,
|
1279 |
+
kernel_size=audio_kernel_size,
|
1280 |
+
stride=audio_stride,
|
1281 |
+
out_channels=audio_embed_dim,
|
1282 |
+
bias=False,
|
1283 |
+
),
|
1284 |
+
],
|
1285 |
+
norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
|
1286 |
+
)
|
1287 |
+
audio_preprocessor = AudioPreprocessor(
|
1288 |
+
img_size=[1, audio_num_mel_bins, audio_target_len],
|
1289 |
+
num_cls_tokens=1,
|
1290 |
+
pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
|
1291 |
+
audio_stem=audio_stem,
|
1292 |
+
)
|
1293 |
+
|
1294 |
+
# depth_stem = PatchEmbedGeneric(
|
1295 |
+
# [
|
1296 |
+
# nn.Conv2d(
|
1297 |
+
# kernel_size=depth_kernel_size,
|
1298 |
+
# in_channels=1,
|
1299 |
+
# out_channels=depth_embed_dim,
|
1300 |
+
# stride=depth_kernel_size,
|
1301 |
+
# bias=False,
|
1302 |
+
# ),
|
1303 |
+
# ],
|
1304 |
+
# norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
|
1305 |
+
# )
|
1306 |
+
#
|
1307 |
+
# depth_preprocessor = RGBDTPreprocessor(
|
1308 |
+
# img_size=[1, 224, 224],
|
1309 |
+
# num_cls_tokens=1,
|
1310 |
+
# pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
|
1311 |
+
# rgbt_stem=None,
|
1312 |
+
# depth_stem=depth_stem,
|
1313 |
+
# )
|
1314 |
+
#
|
1315 |
+
# thermal_stem = PatchEmbedGeneric(
|
1316 |
+
# [
|
1317 |
+
# nn.Conv2d(
|
1318 |
+
# kernel_size=thermal_kernel_size,
|
1319 |
+
# in_channels=1,
|
1320 |
+
# out_channels=thermal_embed_dim,
|
1321 |
+
# stride=thermal_kernel_size,
|
1322 |
+
# bias=False,
|
1323 |
+
# ),
|
1324 |
+
# ],
|
1325 |
+
# norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
|
1326 |
+
# )
|
1327 |
+
# thermal_preprocessor = ThermalPreprocessor(
|
1328 |
+
# img_size=[1, 224, 224],
|
1329 |
+
# num_cls_tokens=1,
|
1330 |
+
# pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
|
1331 |
+
# thermal_stem=thermal_stem,
|
1332 |
+
# )
|
1333 |
+
#
|
1334 |
+
# imu_stem = PatchEmbedGeneric(
|
1335 |
+
# [
|
1336 |
+
# nn.Linear(
|
1337 |
+
# in_features=48,
|
1338 |
+
# out_features=imu_embed_dim,
|
1339 |
+
# bias=False,
|
1340 |
+
# ),
|
1341 |
+
# ],
|
1342 |
+
# norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
|
1343 |
+
# )
|
1344 |
+
#
|
1345 |
+
# imu_preprocessor = IMUPreprocessor(
|
1346 |
+
# img_size=[6, 2000],
|
1347 |
+
# num_cls_tokens=1,
|
1348 |
+
# kernel_size=8,
|
1349 |
+
# embed_dim=imu_embed_dim,
|
1350 |
+
# pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
|
1351 |
+
# imu_stem=imu_stem,
|
1352 |
+
# )
|
1353 |
+
|
1354 |
+
modality_preprocessors = {
|
1355 |
+
ModalityType.VISION: rgbt_preprocessor,
|
1356 |
+
ModalityType.TEXT: text_preprocessor,
|
1357 |
+
ModalityType.AUDIO: audio_preprocessor,
|
1358 |
+
# ModalityType.DEPTH: depth_preprocessor,
|
1359 |
+
# ModalityType.THERMAL: thermal_preprocessor,
|
1360 |
+
# ModalityType.IMU: imu_preprocessor,
|
1361 |
+
}
|
1362 |
+
|
1363 |
+
return nn.ModuleDict(modality_preprocessors)
|
1364 |
+
|
1365 |
+
def _create_modality_trunks(
|
1366 |
+
self,
|
1367 |
+
vision_embed_dim=1024,
|
1368 |
+
vision_num_blocks=24,
|
1369 |
+
vision_num_heads=16,
|
1370 |
+
text_embed_dim=768,
|
1371 |
+
text_num_blocks=12,
|
1372 |
+
text_num_heads=12,
|
1373 |
+
audio_embed_dim=768,
|
1374 |
+
audio_num_blocks=12,
|
1375 |
+
audio_num_heads=12,
|
1376 |
+
audio_drop_path=0.0,
|
1377 |
+
depth_embed_dim=768,
|
1378 |
+
depth_num_blocks=12,
|
1379 |
+
depth_num_heads=12,
|
1380 |
+
depth_drop_path=0.0,
|
1381 |
+
thermal_embed_dim=768,
|
1382 |
+
thermal_num_blocks=12,
|
1383 |
+
thermal_num_heads=12,
|
1384 |
+
thermal_drop_path=0.0,
|
1385 |
+
imu_embed_dim=512,
|
1386 |
+
imu_num_blocks=6,
|
1387 |
+
imu_num_heads=8,
|
1388 |
+
imu_drop_path=0.7,
|
1389 |
+
):
|
1390 |
+
def instantiate_trunk(
|
1391 |
+
embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
|
1392 |
+
):
|
1393 |
+
return SimpleTransformer(
|
1394 |
+
embed_dim=embed_dim,
|
1395 |
+
num_blocks=num_blocks,
|
1396 |
+
ffn_dropout_rate=0.0,
|
1397 |
+
drop_path_rate=drop_path,
|
1398 |
+
attn_target=partial(
|
1399 |
+
MultiheadAttention,
|
1400 |
+
embed_dim=embed_dim,
|
1401 |
+
num_heads=num_heads,
|
1402 |
+
bias=True,
|
1403 |
+
add_bias_kv=add_bias_kv,
|
1404 |
+
),
|
1405 |
+
pre_transformer_layer=nn.Sequential(
|
1406 |
+
nn.LayerNorm(embed_dim, eps=1e-6)
|
1407 |
+
if pre_transformer_ln
|
1408 |
+
else nn.Identity(),
|
1409 |
+
EinOpsRearrange("b l d -> l b d"),
|
1410 |
+
),
|
1411 |
+
post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
|
1412 |
+
)
|
1413 |
+
|
1414 |
+
modality_trunks = {}
|
1415 |
+
modality_trunks[ModalityType.VISION] = instantiate_trunk(
|
1416 |
+
vision_embed_dim,
|
1417 |
+
vision_num_blocks,
|
1418 |
+
vision_num_heads,
|
1419 |
+
pre_transformer_ln=True,
|
1420 |
+
add_bias_kv=False,
|
1421 |
+
drop_path=0.0,
|
1422 |
+
)
|
1423 |
+
modality_trunks[ModalityType.TEXT] = instantiate_trunk(
|
1424 |
+
text_embed_dim,
|
1425 |
+
text_num_blocks,
|
1426 |
+
text_num_heads,
|
1427 |
+
pre_transformer_ln=False,
|
1428 |
+
add_bias_kv=False,
|
1429 |
+
drop_path=0.0,
|
1430 |
+
)
|
1431 |
+
modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
|
1432 |
+
audio_embed_dim,
|
1433 |
+
audio_num_blocks,
|
1434 |
+
audio_num_heads,
|
1435 |
+
pre_transformer_ln=False,
|
1436 |
+
add_bias_kv=True,
|
1437 |
+
drop_path=audio_drop_path,
|
1438 |
+
)
|
1439 |
+
# modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
|
1440 |
+
# depth_embed_dim,
|
1441 |
+
# depth_num_blocks,
|
1442 |
+
# depth_num_heads,
|
1443 |
+
# pre_transformer_ln=False,
|
1444 |
+
# add_bias_kv=True,
|
1445 |
+
# drop_path=depth_drop_path,
|
1446 |
+
# )
|
1447 |
+
# modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
|
1448 |
+
# thermal_embed_dim,
|
1449 |
+
# thermal_num_blocks,
|
1450 |
+
# thermal_num_heads,
|
1451 |
+
# pre_transformer_ln=False,
|
1452 |
+
# add_bias_kv=True,
|
1453 |
+
# drop_path=thermal_drop_path,
|
1454 |
+
# )
|
1455 |
+
# modality_trunks[ModalityType.IMU] = instantiate_trunk(
|
1456 |
+
# imu_embed_dim,
|
1457 |
+
# imu_num_blocks,
|
1458 |
+
# imu_num_heads,
|
1459 |
+
# pre_transformer_ln=False,
|
1460 |
+
# add_bias_kv=True,
|
1461 |
+
# drop_path=imu_drop_path,
|
1462 |
+
# )
|
1463 |
+
|
1464 |
+
return nn.ModuleDict(modality_trunks)
|
1465 |
+
|
1466 |
+
def _create_modality_heads(
|
1467 |
+
self,
|
1468 |
+
out_embed_dim,
|
1469 |
+
vision_embed_dim,
|
1470 |
+
text_embed_dim,
|
1471 |
+
audio_embed_dim,
|
1472 |
+
depth_embed_dim,
|
1473 |
+
thermal_embed_dim,
|
1474 |
+
imu_embed_dim,
|
1475 |
+
):
|
1476 |
+
modality_heads = {}
|
1477 |
+
|
1478 |
+
modality_heads[ModalityType.VISION] = nn.Sequential(
|
1479 |
+
nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
|
1480 |
+
SelectElement(index=0),
|
1481 |
+
nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
|
1482 |
+
)
|
1483 |
+
|
1484 |
+
modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
|
1485 |
+
proj=nn.Sequential(
|
1486 |
+
nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
|
1487 |
+
nn.Linear(text_embed_dim, out_embed_dim, bias=False),
|
1488 |
+
)
|
1489 |
+
)
|
1490 |
+
|
1491 |
+
modality_heads[ModalityType.AUDIO] = nn.Sequential(
|
1492 |
+
nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
|
1493 |
+
SelectElement(index=0),
|
1494 |
+
nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
|
1495 |
+
)
|
1496 |
+
|
1497 |
+
# modality_heads[ModalityType.DEPTH] = nn.Sequential(
|
1498 |
+
# nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
|
1499 |
+
# SelectElement(index=0),
|
1500 |
+
# nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
|
1501 |
+
# )
|
1502 |
+
#
|
1503 |
+
# modality_heads[ModalityType.THERMAL] = nn.Sequential(
|
1504 |
+
# nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
|
1505 |
+
# SelectElement(index=0),
|
1506 |
+
# nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
|
1507 |
+
# )
|
1508 |
+
#
|
1509 |
+
# modality_heads[ModalityType.IMU] = nn.Sequential(
|
1510 |
+
# nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
|
1511 |
+
# SelectElement(index=0),
|
1512 |
+
# nn.Dropout(p=0.5),
|
1513 |
+
# nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
|
1514 |
+
# )
|
1515 |
+
|
1516 |
+
return nn.ModuleDict(modality_heads)
|
1517 |
+
|
1518 |
+
def _create_modality_postprocessors(self, out_embed_dim):
|
1519 |
+
modality_postprocessors = {}
|
1520 |
+
|
1521 |
+
modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
|
1522 |
+
modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
|
1523 |
+
Normalize(dim=-1), LearnableLogitScaling(learnable=True)
|
1524 |
+
)
|
1525 |
+
modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
|
1526 |
+
Normalize(dim=-1),
|
1527 |
+
LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
|
1528 |
+
)
|
1529 |
+
# modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
|
1530 |
+
# Normalize(dim=-1),
|
1531 |
+
# LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
|
1532 |
+
# )
|
1533 |
+
# modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
|
1534 |
+
# Normalize(dim=-1),
|
1535 |
+
# LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
|
1536 |
+
# )
|
1537 |
+
# modality_postprocessors[ModalityType.IMU] = nn.Sequential(
|
1538 |
+
# Normalize(dim=-1),
|
1539 |
+
# LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
|
1540 |
+
# )
|
1541 |
+
|
1542 |
+
return nn.ModuleDict(modality_postprocessors)
|
1543 |
+
|
1544 |
+
def forward(self, inputs):
|
1545 |
+
outputs = {}
|
1546 |
+
for modality_key, modality_value in inputs.items():
|
1547 |
+
reduce_list = (
|
1548 |
+
modality_value.ndim >= 5
|
1549 |
+
) # Audio and Video inputs consist of multiple clips
|
1550 |
+
if reduce_list:
|
1551 |
+
B, S = modality_value.shape[:2]
|
1552 |
+
modality_value = modality_value.reshape(
|
1553 |
+
B * S, *modality_value.shape[2:]
|
1554 |
+
)
|
1555 |
+
|
1556 |
+
if modality_value is not None:
|
1557 |
+
modality_value = self.modality_preprocessors[modality_key](
|
1558 |
+
**{modality_key: modality_value}
|
1559 |
+
)
|
1560 |
+
trunk_inputs = modality_value["trunk"]
|
1561 |
+
head_inputs = modality_value["head"]
|
1562 |
+
modality_value = self.modality_trunks[modality_key](**trunk_inputs)
|
1563 |
+
modality_value = self.modality_heads[modality_key](
|
1564 |
+
modality_value, **head_inputs
|
1565 |
+
)
|
1566 |
+
modality_value = self.modality_postprocessors[modality_key](
|
1567 |
+
modality_value
|
1568 |
+
)
|
1569 |
+
|
1570 |
+
if reduce_list:
|
1571 |
+
modality_value = modality_value.reshape(B, S, -1)
|
1572 |
+
modality_value = modality_value.mean(dim=1)
|
1573 |
+
|
1574 |
+
outputs[modality_key] = modality_value
|
1575 |
+
|
1576 |
+
return outputs
|
1577 |
+
|
1578 |
+
def reconfigure_head(self, k, v):
|
1579 |
+
if k == ModalityType.AUDIO:
|
1580 |
+
return torch.nn.Sequential(v[0], v[2])
|
1581 |
+
elif k == ModalityType.VISION:
|
1582 |
+
return torch.nn.Sequential(v[0], v[2])
|
1583 |
+
else:
|
1584 |
+
return v
|
1585 |
+
|
1586 |
+
def forward_features(self, inputs):
|
1587 |
+
outputs = {}
|
1588 |
+
|
1589 |
+
reconfigured_heads = {k: self.reconfigure_head(k, v) for k, v in self.modality_heads.items()}
|
1590 |
+
|
1591 |
+
for modality_key, modality_value in inputs.items():
|
1592 |
+
reduce_list = (
|
1593 |
+
modality_value.ndim >= 5
|
1594 |
+
) # Audio and Video inputs consist of multiple clips
|
1595 |
+
if reduce_list:
|
1596 |
+
B, S = modality_value.shape[:2]
|
1597 |
+
modality_value = modality_value.reshape(
|
1598 |
+
B * S, *modality_value.shape[2:]
|
1599 |
+
)
|
1600 |
+
|
1601 |
+
if modality_value is not None:
|
1602 |
+
modality_value = self.modality_preprocessors[modality_key](
|
1603 |
+
**{modality_key: modality_value}
|
1604 |
+
)
|
1605 |
+
trunk_inputs = modality_value["trunk"]
|
1606 |
+
head_inputs = modality_value["head"]
|
1607 |
+
modality_value = self.modality_trunks[modality_key](**trunk_inputs)
|
1608 |
+
modality_value = reconfigured_heads[modality_key](
|
1609 |
+
modality_value, **head_inputs
|
1610 |
+
)
|
1611 |
+
modality_value = self.modality_postprocessors[modality_key](
|
1612 |
+
modality_value
|
1613 |
+
)
|
1614 |
+
if modality_key == ModalityType.AUDIO:
|
1615 |
+
modality_value = ReshapeAudio((12, 19))(modality_value)
|
1616 |
+
elif modality_key == ModalityType.VISION:
|
1617 |
+
modality_value = ReshapeSpatial((16, 16))(modality_value)
|
1618 |
+
|
1619 |
+
outputs[modality_key] = modality_value
|
1620 |
+
|
1621 |
+
return outputs
|
1622 |
+
|
1623 |
+
|
1624 |
+
def imagebind_huge(output_path, pretrained=False):
|
1625 |
+
model = ImageBindModel(
|
1626 |
+
vision_embed_dim=1280,
|
1627 |
+
vision_num_blocks=32,
|
1628 |
+
vision_num_heads=16,
|
1629 |
+
text_embed_dim=1024,
|
1630 |
+
text_num_blocks=24,
|
1631 |
+
text_num_heads=16,
|
1632 |
+
out_embed_dim=1024,
|
1633 |
+
audio_drop_path=0.1,
|
1634 |
+
imu_drop_path=0.7,
|
1635 |
+
)
|
1636 |
+
|
1637 |
+
if pretrained:
|
1638 |
+
path = os.path.join(output_path, 'models/imagebind_huge.pth')
|
1639 |
+
|
1640 |
+
if not os.path.exists(path):
|
1641 |
+
print(f"Downloading imagebind weights to {path} ...")
|
1642 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
1643 |
+
torch.hub.download_url_to_file(
|
1644 |
+
"https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
|
1645 |
+
path,
|
1646 |
+
progress=True,
|
1647 |
+
)
|
1648 |
+
|
1649 |
+
model.load_state_dict(torch.load(path), strict=False)
|
1650 |
+
|
1651 |
+
return model
|
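# Hedged usage sketch (illustrative, not part of the original upload): build the
# huge variant without downloading weights; "./outputs" is a hypothetical
# directory and is only consulted when pretrained=True.
def _example_build_imagebind_huge():
    model = imagebind_huge("./outputs", pretrained=False)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    return n_params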
1652 |
+
|
1653 |
+
|
1654 |
+
DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
|
1655 |
+
|
1656 |
+
|
1657 |
+
def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
|
1658 |
+
# Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
|
1659 |
+
waveform -= waveform.mean()
|
1660 |
+
fbank = torchaudio.compliance.kaldi.fbank(
|
1661 |
+
waveform,
|
1662 |
+
htk_compat=True,
|
1663 |
+
sample_frequency=sample_rate,
|
1664 |
+
use_energy=False,
|
1665 |
+
window_type="hanning",
|
1666 |
+
num_mel_bins=num_mel_bins,
|
1667 |
+
dither=0.0,
|
1668 |
+
frame_length=25,
|
1669 |
+
frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
|
1670 |
+
)
|
1671 |
+
# Convert to [mel_bins, num_frames] shape
|
1672 |
+
fbank = fbank.transpose(0, 1)
|
1673 |
+
# Pad to target_length
|
1674 |
+
n_frames = fbank.size(1)
|
1675 |
+
p = target_length - n_frames
|
1676 |
+
# if p is too large (say >20%), flash a warning
|
1677 |
+
if abs(p) / n_frames > 0.2:
|
1678 |
+
logging.warning(
|
1679 |
+
"Large gap between audio n_frames(%d) and "
|
1680 |
+
"target_length (%d). Is the audio_target_length "
|
1681 |
+
"setting correct?",
|
1682 |
+
n_frames,
|
1683 |
+
target_length,
|
1684 |
+
)
|
1685 |
+
# cut and pad
|
1686 |
+
if p > 0:
|
1687 |
+
fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
|
1688 |
+
elif p < 0:
|
1689 |
+
fbank = fbank[:, 0:target_length]
|
1690 |
+
# Convert to [1, mel_bins, num_frames] shape, essentially like a 1
|
1691 |
+
# channel image
|
1692 |
+
fbank = fbank.unsqueeze(0)
|
1693 |
+
return fbank
|
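# Hedged usage sketch (illustrative, reusing the module-level torch import):
# 2 s of 16 kHz noise gives roughly 200 frames at a 10 ms shift, which is then
# padded up to target_length=204, matching the clip settings used by
# load_and_transform_audio_data below.
def _example_waveform2melspec():
    sample_rate = 16000
    waveform = torch.randn(1, 2 * sample_rate)  # (channels, num_samples)
    fbank = waveform2melspec(waveform, sample_rate, num_mel_bins=128, target_length=204)
    return fbank.shape  # torch.Size([1, 128, 204])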
1694 |
+
|
1695 |
+
|
1696 |
+
def get_clip_timepoints(clip_sampler, duration):
|
1697 |
+
# Read out all clips in this video
|
1698 |
+
all_clips_timepoints = []
|
1699 |
+
is_last_clip = False
|
1700 |
+
end = 0.0
|
1701 |
+
while not is_last_clip:
|
1702 |
+
start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
|
1703 |
+
all_clips_timepoints.append((start, end))
|
1704 |
+
return all_clips_timepoints
|
1705 |
+
|
1706 |
+
|
1707 |
+
def load_and_transform_vision_data(image_paths, device):
|
1708 |
+
if image_paths is None:
|
1709 |
+
return None
|
1710 |
+
|
1711 |
+
image_outputs = []
|
1712 |
+
for image_path in image_paths:
|
1713 |
+
data_transform = transforms.Compose(
|
1714 |
+
[
|
1715 |
+
transforms.Resize(
|
1716 |
+
224, interpolation=transforms.InterpolationMode.BICUBIC
|
1717 |
+
),
|
1718 |
+
transforms.CenterCrop(224),
|
1719 |
+
transforms.ToTensor(),
|
1720 |
+
transforms.Normalize(
|
1721 |
+
mean=(0.48145466, 0.4578275, 0.40821073),
|
1722 |
+
std=(0.26862954, 0.26130258, 0.27577711),
|
1723 |
+
),
|
1724 |
+
]
|
1725 |
+
)
|
1726 |
+
with open(image_path, "rb") as fopen:
|
1727 |
+
image = Image.open(fopen).convert("RGB")
|
1728 |
+
|
1729 |
+
image = data_transform(image).to(device)
|
1730 |
+
image_outputs.append(image)
|
1731 |
+
return torch.stack(image_outputs, dim=0)
|
1732 |
+
|
1733 |
+
|
1734 |
+
def load_and_transform_audio_data(
|
1735 |
+
audio_paths,
|
1736 |
+
device,
|
1737 |
+
num_mel_bins=128,
|
1738 |
+
target_length=204,
|
1739 |
+
sample_rate=16000,
|
1740 |
+
clip_duration=2,
|
1741 |
+
clips_per_video=3,
|
1742 |
+
mean=-4.268,
|
1743 |
+
std=9.138,
|
1744 |
+
):
|
1745 |
+
from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
|
1746 |
+
|
1747 |
+
if audio_paths is None:
|
1748 |
+
return None
|
1749 |
+
|
1750 |
+
audio_outputs = []
|
1751 |
+
clip_sampler = ConstantClipsPerVideoSampler(
|
1752 |
+
clip_duration=clip_duration, clips_per_video=clips_per_video
|
1753 |
+
)
|
1754 |
+
|
1755 |
+
for audio_path in audio_paths:
|
1756 |
+
waveform, sr = torchaudio.load(audio_path)
|
1757 |
+
if sample_rate != sr:
|
1758 |
+
waveform = torchaudio.functional.resample(
|
1759 |
+
waveform, orig_freq=sr, new_freq=sample_rate
|
1760 |
+
)
|
1761 |
+
all_clips_timepoints = get_clip_timepoints(
|
1762 |
+
clip_sampler, waveform.size(1) / sample_rate
|
1763 |
+
)
|
1764 |
+
all_clips = []
|
1765 |
+
for clip_timepoints in all_clips_timepoints:
|
1766 |
+
waveform_clip = waveform[
|
1767 |
+
:,
|
1768 |
+
int(clip_timepoints[0] * sample_rate): int(
|
1769 |
+
clip_timepoints[1] * sample_rate
|
1770 |
+
),
|
1771 |
+
]
|
1772 |
+
waveform_melspec = waveform2melspec(
|
1773 |
+
waveform_clip, sample_rate, num_mel_bins, target_length
|
1774 |
+
)
|
1775 |
+
all_clips.append(waveform_melspec)
|
1776 |
+
|
1777 |
+
normalize = transforms.Normalize(mean=mean, std=std)
|
1778 |
+
all_clips = [normalize(ac).to(device) for ac in all_clips]
|
1779 |
+
|
1780 |
+
all_clips = torch.stack(all_clips, dim=0)
|
1781 |
+
audio_outputs.append(all_clips)
|
1782 |
+
|
1783 |
+
return torch.stack(audio_outputs, dim=0)
|
1784 |
+
|
1785 |
+
|
1786 |
+
class UnNormalize(object):
|
1787 |
+
def __init__(self, mean, std):
|
1788 |
+
self.mean = mean
|
1789 |
+
self.std = std
|
1790 |
+
|
1791 |
+
def __call__(self, image):
|
1792 |
+
image2 = torch.clone(image)
|
1793 |
+
for t, m, s in zip(image2, self.mean, self.std):
|
1794 |
+
t.mul_(s).add_(m)
|
1795 |
+
return image2
|
1796 |
+
|
1797 |
+
|
1798 |
+
norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
1799 |
+
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
1800 |
+
|
1801 |
+
|
1802 |
+
class TorchPCA(object):
|
1803 |
+
|
1804 |
+
def __init__(self, n_components):
|
1805 |
+
self.n_components = n_components
|
1806 |
+
|
1807 |
+
def fit(self, X):
|
1808 |
+
self.mean_ = X.mean(dim=0)
|
1809 |
+
unbiased = X - self.mean_.unsqueeze(0)
|
1810 |
+
U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
|
1811 |
+
self.components_ = V.T
|
1812 |
+
self.singular_values_ = S
|
1813 |
+
return self
|
1814 |
+
|
1815 |
+
def transform(self, X):
|
1816 |
+
t0 = X - self.mean_.unsqueeze(0)
|
1817 |
+
projected = t0 @ self.components_.T
|
1818 |
+
return projected
|
1819 |
+
|
1820 |
+
|
1821 |
+
def pca(image_feats_list, dim=3, fit_pca=None):
|
1822 |
+
# from sklearn.decomposition import PCA
|
1823 |
+
|
1824 |
+
device = image_feats_list[0].device
|
1825 |
+
|
1826 |
+
def flatten(tensor, target_size=None):
|
1827 |
+
if target_size is not None and fit_pca is None:
|
1828 |
+
tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")  # resize so all feature maps share a spatial size
|
1829 |
+
B, C, H, W = tensor.shape
|
1830 |
+
return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
|
1831 |
+
|
1832 |
+
if len(image_feats_list) > 1 and fit_pca is None:
|
1833 |
+
target_size = image_feats_list[0].shape[2]
|
1834 |
+
else:
|
1835 |
+
target_size = None
|
1836 |
+
|
1837 |
+
flattened_feats = []
|
1838 |
+
for feats in image_feats_list:
|
1839 |
+
flattened_feats.append(flatten(feats, target_size))
|
1840 |
+
x = torch.cat(flattened_feats, dim=0)
|
1841 |
+
|
1842 |
+
if fit_pca is None:
|
1843 |
+
# fit_pca = PCA(n_components=dim, svd_solver='arpack').fit(np.nan_to_num(x.detach().numpy()))
|
1844 |
+
fit_pca = TorchPCA(n_components=dim).fit(x)
|
1845 |
+
|
1846 |
+
reduced_feats = []
|
1847 |
+
for feats in image_feats_list:
|
1848 |
+
# x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
|
1849 |
+
x_red = fit_pca.transform(flatten(feats))
|
1850 |
+
x_red -= x_red.min(dim=0, keepdim=True).values
|
1851 |
+
x_red /= x_red.max(dim=0, keepdim=True).values
|
1852 |
+
B, C, H, W = feats.shape
|
1853 |
+
reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
|
1854 |
+
|
1855 |
+
return reduced_feats, fit_pca
|
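# Hedged usage sketch (illustrative): reduce a random (B, C, H, W) feature map to
# 3 channels for visualization, then reuse the fitted PCA to project a second
# map into the same basis.
def _example_pca():
    feats_a = torch.randn(1, 64, 16, 16)
    feats_b = torch.randn(1, 64, 16, 16)
    [red_a], fit = pca([feats_a], dim=3)
    [red_b], _ = pca([feats_b], dim=3, fit_pca=fit)
    return red_a.shape, red_b.shape  # both torch.Size([1, 3, 16, 16])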
1856 |
+
|
1857 |
+
|
1858 |
+
def my_load_audio(audio_file):
|
1859 |
+
loaded_waveform, obs_sr = torchaudio.load(audio_file)
|
1860 |
+
loaded_waveform = loaded_waveform[0]
|
1861 |
+
|
1862 |
+
neg_waveform, neg_obs_sr = None, None
|
1863 |
+
from denseav.data.AVDatasets import prep_waveform
|
1864 |
+
|
1865 |
+
(waveform,
|
1866 |
+
spectrogram,
|
1867 |
+
audio_length,
|
1868 |
+
total_length,
|
1869 |
+
original_length,
|
1870 |
+
mask,
|
1871 |
+
pos_mask) = prep_waveform(
|
1872 |
+
loaded_waveform,
|
1873 |
+
obs_sr,
|
1874 |
+
10,
|
1875 |
+
128,
|
1876 |
+
-4.268,
|
1877 |
+
9.138,
|
1878 |
+
16000,
|
1879 |
+
True,
|
1880 |
+
False,
|
1881 |
+
False,
|
1882 |
+
neg_waveform,
|
1883 |
+
neg_obs_sr,
|
1884 |
+
False,
|
1885 |
+
)
|
1886 |
+
|
1887 |
+
patch_size = 204
|
1888 |
+
n_tiles = spectrogram.shape[1] // patch_size
|
1889 |
+
assert n_tiles == 5
|
1890 |
+
|
1891 |
+
patches = []
|
1892 |
+
for i in range(n_tiles):
|
1893 |
+
patches.append(spectrogram[:, i * patch_size:(i + 1) * patch_size, :])
|
1894 |
+
|
1895 |
+
patches = torch.cat(patches, dim=0).permute(0, 2, 1).unsqueeze(1)
|
1896 |
+
return patches
|
1897 |
+
|
1898 |
+
|
1899 |
+
class ImageBindImageFeaturizer(nn.Module):
|
1900 |
+
|
1901 |
+
def __init__(self, output_path, model=None):
|
1902 |
+
super().__init__()
|
1903 |
+
if model is not None:
|
1904 |
+
self.model = model
|
1905 |
+
else:
|
1906 |
+
self.model = imagebind_huge(output_path, pretrained=True).cuda()
|
1907 |
+
|
1908 |
+
def forward(self, image, include_cls):
|
1909 |
+
inputs = {
|
1910 |
+
ModalityType.VISION: image,
|
1911 |
+
}
|
1912 |
+
|
1913 |
+
patch_tokens, cls_tokens = self.model.forward_features(inputs)[ModalityType.VISION]
|
1914 |
+
patch_tokens = patch_tokens.permute(0, 3, 1, 2)
|
1915 |
+
|
1916 |
+
if include_cls:
|
1917 |
+
return patch_tokens, cls_tokens
|
1918 |
+
else:
|
1919 |
+
return patch_tokens
|
1920 |
+
|
1921 |
+
|
1922 |
+
class ImageBindAudioFeaturizer(nn.Module):
|
1923 |
+
|
1924 |
+
def __init__(self, output_path, model=None):
|
1925 |
+
super().__init__()
|
1926 |
+
if model is not None:
|
1927 |
+
self.model = model
|
1928 |
+
else:
|
1929 |
+
self.model = imagebind_huge(output_path, pretrained=True).cuda()
|
1930 |
+
|
1931 |
+
def forward(self, spec, include_cls):
|
1932 |
+
|
1933 |
+
patch_size = 204
|
1934 |
+
n_tiles = spec.shape[2] // patch_size
|
1935 |
+
assert n_tiles == 5
|
1936 |
+
|
1937 |
+
patches = []
|
1938 |
+
for i in range(n_tiles):
|
1939 |
+
patches.append(spec[:, :, i * patch_size:(i + 1) * patch_size, :])
|
1940 |
+
|
1941 |
+
patches = torch.cat(patches, dim=1).permute(0, 1, 3, 2).unsqueeze(2)
|
1942 |
+
|
1943 |
+
inputs = {
|
1944 |
+
ModalityType.AUDIO: patches,
|
1945 |
+
}
|
1946 |
+
|
1947 |
+
patch_tokens, cls_token = self.model.forward_features(inputs)[ModalityType.AUDIO]
|
1948 |
+
|
1949 |
+
patch_tokens = patch_tokens.permute(0, 4, 2, 1, 3)
|
1950 |
+
b, c, h, p, w = patch_tokens.shape
|
1951 |
+
patch_tokens = patch_tokens.reshape(b, c, h, w * p)
|
1952 |
+
|
1953 |
+
cls_token = cls_token.reshape(b, p, -1).mean(1)
|
1954 |
+
|
1955 |
+
if include_cls:
|
1956 |
+
return patch_tokens, cls_token
|
1957 |
+
else:
|
1958 |
+
return patch_tokens
|
1959 |
+
|
1960 |
+
|
1961 |
+
if __name__ == "__main__":
|
1962 |
+
image_paths = ["../../samples/dog_image.jpg", "../../samples/car_image.jpg", "../../samples/bird_image.jpg"]
|
1963 |
+
audio_paths = ["../../samples/dog_audio.wav", "../../samples/car_audio.wav", "../../samples/bird_audio.wav"]
|
1964 |
+
|
1965 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
1966 |
+
|
1967 |
+
# Instantiate model
|
1968 |
+
model = imagebind_huge("../../", pretrained=True)
|
1969 |
+
model.eval()
|
1970 |
+
model.to(device)
|
1971 |
+
|
1972 |
+
audio_inputs = torch.cat([my_load_audio(af).unsqueeze(0) for af in audio_paths], dim=0).cuda()
|
1973 |
+
# Load data
|
1974 |
+
inputs = {
|
1975 |
+
ModalityType.VISION: load_and_transform_vision_data(image_paths, device),
|
1976 |
+
# ModalityType.AUDIO: load_and_transform_audio_data(audio_paths, device, clip_duration=2, clips_per_video=5),
|
1977 |
+
ModalityType.AUDIO: audio_inputs,
|
1978 |
+
|
1979 |
+
}
|
1980 |
+
|
1981 |
+
with torch.no_grad():
|
1982 |
+
embeddings = model.forward_features(inputs)
|
1983 |
+
cls_tokens = model.forward(inputs)
|
1984 |
+
|
1985 |
+
audio_cls_token = embeddings["audio"][1].reshape(3, 5, -1).mean(1)
|
1986 |
+
|
1987 |
+
sims1 = torch.einsum(
|
1988 |
+
"bc,dc->bd",
|
1989 |
+
embeddings["vision"][1],
|
1990 |
+
audio_cls_token)
|
1991 |
+
|
1992 |
+
print(torch.softmax(sims1, dim=1).cpu().numpy())
|
1993 |
+
#
|
1994 |
+
# sims2 = torch.einsum(
|
1995 |
+
# "bc,dc->bd",
|
1996 |
+
# embeddings["vision"].mean(1).mean(1),
|
1997 |
+
# embeddings["audio"].mean(1).mean(1).mean(1)
|
1998 |
+
# )
|
1999 |
+
#
|
2000 |
+
# print(torch.softmax(sims2, dim=1).cpu().numpy())
|
2001 |
+
#
|
2002 |
+
#
|
2003 |
+
# img_num = 0
|
2004 |
+
# img_feats = F.normalize(embeddings["vision"].permute(0, 3, 1, 2), dim=1)
|
2005 |
+
# [red_img_feats], fit_pca = pca([img_feats])
|
2006 |
+
#
|
2007 |
+
# fig, axes = plt.subplots(2, 2, figsize=(4 * 2, 4 * 2))
|
2008 |
+
# axes[0][0].imshow(unnorm(inputs["vision"][0].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu())
|
2009 |
+
# axes[0][1].imshow(unnorm(inputs["vision"][1].unsqueeze(0))[0].permute(1, 2, 0).detach().cpu())
|
2010 |
+
# axes[1][0].imshow(red_img_feats[0].permute(1, 2, 0).detach().cpu())
|
2011 |
+
# axes[1][1].imshow(red_img_feats[1].permute(1, 2, 0).detach().cpu())
|
2012 |
+
# plt.tight_layout()
|
2013 |
+
# plt.show()
|
2014 |
+
#
|
2015 |
+
audio_embs = F.normalize(embeddings["audio"][0], dim=-1)
|
2016 |
+
b, n, h, w, c = audio_embs.shape
|
2017 |
+
|
2018 |
+
audio_embs = audio_embs.permute(0, 4, 2, 1, 3).reshape(b, c, h, w * n)
|
2019 |
+
|
2020 |
+
b, n, c, h, w = inputs["audio"].shape
|
2021 |
+
audio_inputs = inputs["audio"].permute(0, 2, 3, 1, 4).reshape(b, c, h, w * n)
|
2022 |
+
|
2023 |
+
print("here")
|
2024 |
+
|
2025 |
+
for img_num in range(3):
|
2026 |
+
[red_audio], fit_pca = pca([audio_embs[img_num].unsqueeze(0)])
|
2027 |
+
fig, axes = plt.subplots(2, 1, figsize=(10 * 1, 4 * 2))
|
2028 |
+
axes[0].imshow(audio_inputs[img_num, 0].detach().cpu())
|
2029 |
+
axes[1].imshow(red_audio[0].permute(1, 2, 0).detach().cpu())
|
2030 |
+
plt.tight_layout()
|
2031 |
+
plt.show()
|
2032 |
+
|
2033 |
+
print("here")
|
DenseAV/denseav/featurizers/__init__.py
ADDED
File without changes
|
DenseAV/denseav/plotting.py
ADDED
@@ -0,0 +1,244 @@
1 |
+
import os
|
2 |
+
from collections import defaultdict
|
3 |
+
|
4 |
+
import matplotlib.colors as mcolors
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import numpy as np
|
7 |
+
import scipy.io.wavfile as wavfile
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import torchvision
|
11 |
+
from moviepy.editor import VideoFileClip, AudioFileClip
|
12 |
+
from base64 import b64encode
|
13 |
+
from denseav.shared import pca
|
14 |
+
|
15 |
+
|
16 |
+
def write_video_with_audio(video_frames, audio_array, video_fps, audio_fps, output_path):
|
17 |
+
"""
|
18 |
+
Writes video frames and audio to a specified path.
|
19 |
+
|
20 |
+
Parameters:
|
21 |
+
- video_frames: torch.Tensor of shape (num_frames, height, width, channels)
|
22 |
+
- audio_array: torch.Tensor of shape (num_channels, num_samples)
|
23 |
+
- video_fps: int, frames per second of the video
|
24 |
+
- audio_fps: int, sample rate of the audio
|
25 |
+
- output_path: str, path to save the final video with audio
|
26 |
+
"""
|
27 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
28 |
+
|
29 |
+
temp_video_path = output_path.replace('.mp4', '_temp.mp4')
|
30 |
+
temp_audio_path = output_path.replace('.mp4', '_temp_audio.wav')
|
31 |
+
video_options = {
|
32 |
+
'crf': '23',
|
33 |
+
'preset': 'slow',
|
34 |
+
'bit_rate': '1000k'}
|
35 |
+
|
36 |
+
if audio_array is not None:
|
37 |
+
torchvision.io.write_video(
|
38 |
+
filename=temp_video_path,
|
39 |
+
video_array=video_frames,
|
40 |
+
fps=video_fps,
|
41 |
+
options=video_options
|
42 |
+
)
|
43 |
+
|
44 |
+
wavfile.write(temp_audio_path, audio_fps, audio_array.cpu().to(torch.float64).permute(1, 0).numpy())
|
45 |
+
video_clip = VideoFileClip(temp_video_path)
|
46 |
+
audio_clip = AudioFileClip(temp_audio_path)
|
47 |
+
final_clip = video_clip.set_audio(audio_clip)
|
48 |
+
final_clip.write_videofile(output_path, codec='libx264', verbose=False)
|
49 |
+
os.remove(temp_video_path)
|
50 |
+
os.remove(temp_audio_path)
|
51 |
+
else:
|
52 |
+
torchvision.io.write_video(
|
53 |
+
filename=output_path,
|
54 |
+
video_array=video_frames,
|
55 |
+
fps=video_fps,
|
56 |
+
options=video_options
|
57 |
+
)
|
58 |
+
|
59 |
+
|
60 |
+
def alpha_blend_layers(layers):
|
61 |
+
blended_image = layers[0]
|
62 |
+
for layer in layers[1:]:
|
63 |
+
rgb1, alpha1 = blended_image[:, :3, :, :], blended_image[:, 3:4, :, :]
|
64 |
+
rgb2, alpha2 = layer[:, :3, :, :], layer[:, 3:4, :, :]
|
65 |
+
alpha_out = alpha2 + alpha1 * (1 - alpha2)
|
66 |
+
rgb_out = (rgb2 * alpha2 + rgb1 * alpha1 * (1 - alpha2)) / alpha_out.clamp(min=1e-7)
|
67 |
+
blended_image = torch.cat([rgb_out, alpha_out], dim=1)
|
68 |
+
return (blended_image[:, :3] * 255).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1)
|
69 |
+
|
70 |
+
|
71 |
+
def _prep_sims_for_plotting(sim_by_head, frames):
|
72 |
+
with torch.no_grad():
|
73 |
+
results = defaultdict(list)
|
74 |
+
n_frames, _, vh, vw = frames.shape
|
75 |
+
|
76 |
+
sims = sim_by_head.max(dim=1).values
|
77 |
+
|
78 |
+
n_audio_feats = sims.shape[-1]
|
79 |
+
for frame_num in range(n_frames):
|
80 |
+
selected_audio_feat = int((frame_num / n_frames) * n_audio_feats)
|
81 |
+
|
82 |
+
selected_sim = F.interpolate(
|
83 |
+
sims[frame_num, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
|
84 |
+
size=(vh, vw),
|
85 |
+
mode="bicubic")
|
86 |
+
|
87 |
+
results["sims_all"].append(selected_sim)
|
88 |
+
|
89 |
+
for head in range(sim_by_head.shape[1]):
|
90 |
+
selected_sim = F.interpolate(
|
91 |
+
sim_by_head[frame_num, head, :, :, selected_audio_feat].unsqueeze(0).unsqueeze(0),
|
92 |
+
size=(vh, vw),
|
93 |
+
mode="bicubic")
|
94 |
+
results[f"sims_{head + 1}"].append(selected_sim)
|
95 |
+
|
96 |
+
results = {k: torch.cat(v, dim=0) for k, v in results.items()}
|
97 |
+
return results
|
98 |
+
|
99 |
+
|
100 |
+
def get_plasma_with_alpha():
|
101 |
+
plasma = plt.cm.plasma(np.linspace(0, 1, 256))
|
102 |
+
alphas = np.linspace(0, 1, 256)
|
103 |
+
plasma_with_alpha = np.zeros((256, 4))
|
104 |
+
plasma_with_alpha[:, 0:3] = plasma[:, 0:3]
|
105 |
+
plasma_with_alpha[:, 3] = alphas
|
106 |
+
return mcolors.ListedColormap(plasma_with_alpha)
|
107 |
+
|
108 |
+
|
109 |
+
def get_inferno_with_alpha_2(alpha=0.5, k=30):
|
110 |
+
k_fraction = k / 100.0
|
111 |
+
custom_cmap = np.zeros((256, 4))
|
112 |
+
threshold_index = int(k_fraction * 256)
|
113 |
+
custom_cmap[:threshold_index, :3] = 0 # RGB values for black
|
114 |
+
custom_cmap[:threshold_index, 3] = alpha # Alpha value
|
115 |
+
remaining_inferno = plt.cm.inferno(np.linspace(0, 1, 256 - threshold_index))
|
116 |
+
custom_cmap[threshold_index:, :3] = remaining_inferno[:, :3]
|
117 |
+
custom_cmap[threshold_index:, 3] = alpha # Alpha value
|
118 |
+
return mcolors.ListedColormap(custom_cmap)
|
119 |
+
|
120 |
+
|
121 |
+
def get_inferno_with_alpha():
|
122 |
+
plasma = plt.cm.inferno(np.linspace(0, 1, 256))
|
123 |
+
alphas = np.linspace(0, 1, 256)
|
124 |
+
plasma_with_alpha = np.zeros((256, 4))
|
125 |
+
plasma_with_alpha[:, 0:3] = plasma[:, 0:3]
|
126 |
+
plasma_with_alpha[:, 3] = alphas
|
127 |
+
return mcolors.ListedColormap(plasma_with_alpha)
|
128 |
+
|
129 |
+
|
130 |
+
red_cmap = mcolors.LinearSegmentedColormap('RedMap', segmentdata={
|
131 |
+
'red': [(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
|
132 |
+
'green': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
|
133 |
+
'blue': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
|
134 |
+
'alpha': [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
|
135 |
+
})
|
136 |
+
|
137 |
+
blue_cmap = mcolors.LinearSegmentedColormap('BlueMap', segmentdata={
|
138 |
+
'red': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
|
139 |
+
'green': [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)],
|
140 |
+
'blue': [(0.0, 1.0, 1.0), (1.0, 1.0, 1.0)],
|
141 |
+
'alpha': [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
|
142 |
+
})
|
143 |
+
|
144 |
+
|
145 |
+
def plot_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
|
146 |
+
prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
|
147 |
+
n_frames, _, vh, vw = frames.shape
|
148 |
+
sims_all = prepped_sims["sims_all"].clamp_min(0)
|
149 |
+
sims_all -= sims_all.min()
|
150 |
+
sims_all = sims_all / sims_all.max()
|
151 |
+
cmap = get_inferno_with_alpha()
|
152 |
+
layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
|
153 |
+
layer2 = torch.tensor(cmap(sims_all.squeeze().detach().cpu())).permute(0, 3, 1, 2)
|
154 |
+
write_video_with_audio(
|
155 |
+
alpha_blend_layers([layer1, layer2]),
|
156 |
+
audio,
|
157 |
+
video_fps,
|
158 |
+
audio_fps,
|
159 |
+
output_filename)
|
160 |
+
|
161 |
+
|
162 |
+
def plot_2head_attention_video(sims_by_head, frames, audio, video_fps, audio_fps, output_filename):
|
163 |
+
prepped_sims = _prep_sims_for_plotting(sims_by_head, frames)
|
164 |
+
sims_1 = prepped_sims["sims_1"]
|
165 |
+
sims_2 = prepped_sims["sims_2"]
|
166 |
+
|
167 |
+
n_frames, _, vh, vw = frames.shape
|
168 |
+
|
169 |
+
mask = sims_1 > sims_2
|
170 |
+
sims_1 *= mask
|
171 |
+
sims_2 *= (~mask)
|
172 |
+
|
173 |
+
sims_1 = sims_1.clamp_min(0)
|
174 |
+
sims_1 -= sims_1.min()
|
175 |
+
sims_1 = sims_1 / sims_1.max()
|
176 |
+
|
177 |
+
sims_2 = sims_2.clamp_min(0)
|
178 |
+
sims_2 -= sims_2.min()
|
179 |
+
sims_2 = sims_2 / sims_2.max()
|
180 |
+
|
181 |
+
layer1 = torch.cat([frames, torch.ones(n_frames, 1, vh, vw)], axis=1)
|
182 |
+
layer2_head1 = torch.tensor(red_cmap(sims_1.squeeze().detach().cpu())).permute(0, 3, 1, 2)
|
183 |
+
layer2_head2 = torch.tensor(blue_cmap(sims_2.squeeze().detach().cpu())).permute(0, 3, 1, 2)
|
184 |
+
|
185 |
+
write_video_with_audio(
|
186 |
+
alpha_blend_layers([layer1, layer2_head1, layer2_head2]),
|
187 |
+
audio,
|
188 |
+
video_fps,
|
189 |
+
audio_fps,
|
190 |
+
output_filename)
|
191 |
+
|
192 |
+
|
193 |
+
def plot_feature_video(image_feats,
|
194 |
+
audio_feats,
|
195 |
+
frames,
|
196 |
+
audio,
|
197 |
+
video_fps,
|
198 |
+
audio_fps,
|
199 |
+
video_filename,
|
200 |
+
audio_filename):
|
201 |
+
with torch.no_grad():
|
202 |
+
image_feats_ = image_feats.cpu()
|
203 |
+
audio_feats_ = audio_feats.cpu()
|
204 |
+
[red_img_feats, red_audio_feats], _ = pca([
|
205 |
+
image_feats_,
|
206 |
+
audio_feats_, # .tile(image_feats_.shape[0], 1, 1, 1)
|
207 |
+
])
|
208 |
+
_, _, vh, vw = frames.shape
|
209 |
+
red_img_feats = F.interpolate(red_img_feats, size=(vh, vw), mode="bicubic")
|
210 |
+
red_audio_feats = red_audio_feats[0].unsqueeze(0)
|
211 |
+
red_audio_feats = F.interpolate(red_audio_feats, size=(50, red_img_feats.shape[0]), mode="bicubic")
|
212 |
+
|
213 |
+
write_video_with_audio(
|
214 |
+
(red_img_feats.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8),
|
215 |
+
audio,
|
216 |
+
video_fps,
|
217 |
+
audio_fps,
|
218 |
+
video_filename)
|
219 |
+
|
220 |
+
red_audio_feats_expanded = red_audio_feats.tile(red_img_feats.shape[0], 1, 1, 1)
|
221 |
+
red_audio_feats_expanded = F.interpolate(red_audio_feats_expanded, scale_factor=6, mode="bicubic")
|
222 |
+
for i in range(red_img_feats.shape[0]):
|
223 |
+
center_index = i * 6
|
224 |
+
min_index = max(center_index - 2, 0)
|
225 |
+
max_index = min(center_index + 2, red_audio_feats_expanded.shape[-1])
|
226 |
+
red_audio_feats_expanded[i, :, :, min_index:max_index] = 1
|
227 |
+
|
228 |
+
write_video_with_audio(
|
229 |
+
(red_audio_feats_expanded.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8),
|
230 |
+
audio,
|
231 |
+
video_fps,
|
232 |
+
audio_fps,
|
233 |
+
audio_filename)
|
234 |
+
|
235 |
+
|
236 |
+
def display_video_in_notebook(path):
|
237 |
+
from IPython.display import HTML, display
|
238 |
+
mp4 = open(path, 'rb').read()
|
239 |
+
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
|
240 |
+
display(HTML("""
|
241 |
+
<video width=400 controls>
|
242 |
+
<source src="%s" type="video/mp4">
|
243 |
+
</video>
|
244 |
+
""" % data_url))
|
DenseAV/denseav/saved_models.py
ADDED
@@ -0,0 +1,262 @@
1 |
+
import os
|
2 |
+
import re
|
3 |
+
from os.path import join
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
|
9 |
+
def get_latest(name, checkpoint_dir, extra_args=None):
|
10 |
+
if extra_args is None:
|
11 |
+
extra_args = dict()
|
12 |
+
files = os.listdir(join(checkpoint_dir, name))
|
13 |
+
steps = torch.tensor([int(f.split("step=")[-1].split(".")[0]) for f in files])
|
14 |
+
selected = files[steps.argmax()]
|
15 |
+
return dict(
|
16 |
+
chkpt_name=os.path.join(name, selected),
|
17 |
+
extra_args=extra_args)
|
18 |
+
|
19 |
+
|
20 |
+
DS_PARAM_REGEX = r'_forward_module\.(.+)'
|
21 |
+
|
22 |
+
|
23 |
+
def convert_deepspeed_checkpoint(deepspeed_ckpt_path: str, pl_ckpt_path: str = None):
|
24 |
+
'''
|
25 |
+
Creates a PyTorch Lightning checkpoint from the DeepSpeed checkpoint directory, while patching
|
26 |
+
in parameters which are improperly loaded by the DeepSpeed conversion utility.
|
27 |
+
deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
|
28 |
+
pl_ckpt_path: Path to the reconstructed PyTorch Lightning checkpoint. If not specified, will be
|
29 |
+
placed in the same directory as the DeepSpeed checkpoint directory with the same name but
|
30 |
+
a .pt extension.
|
31 |
+
Returns: path to the converted checkpoint.
|
32 |
+
'''
|
33 |
+
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
|
34 |
+
|
35 |
+
|
36 |
+
if not (deepspeed_ckpt_path.endswith('.ckpt') and os.path.isdir(deepspeed_ckpt_path)):
|
37 |
+
raise ValueError(
|
38 |
+
'args.ckpt_dir should point to the checkpoint directory'
|
39 |
+
' output by DeepSpeed (e.g. "last.ckpt" or "epoch=4-step=39150.ckpt").'
|
40 |
+
)
|
41 |
+
|
42 |
+
# Convert state dict to PyTorch format
|
43 |
+
if not pl_ckpt_path:
|
44 |
+
pl_ckpt_path = f'{deepspeed_ckpt_path[:-4]}pt' # .ckpt --> .pt
|
45 |
+
|
46 |
+
if not os.path.exists(pl_ckpt_path):
|
47 |
+
convert_zero_checkpoint_to_fp32_state_dict(deepspeed_ckpt_path, pl_ckpt_path)
|
48 |
+
|
49 |
+
# Patch in missing parameters that failed to be converted by DeepSpeed utility
|
50 |
+
pl_ckpt = _merge_deepspeed_weights(deepspeed_ckpt_path, pl_ckpt_path)
|
51 |
+
torch.save(pl_ckpt, pl_ckpt_path)
|
52 |
+
|
53 |
+
return pl_ckpt_path
|
54 |
+
|
55 |
+
|
56 |
+
def get_optim_files(checkpoint_dir):
|
57 |
+
files = sorted([f for f in os.listdir(checkpoint_dir) if "optim" in f])
|
58 |
+
return [join(checkpoint_dir, f) for f in files]
|
59 |
+
|
60 |
+
|
61 |
+
def get_model_state_file(checkpoint_dir, zero_stage):
|
62 |
+
f = [f for f in os.listdir(checkpoint_dir) if "model_states" in f][0]
|
63 |
+
return join(checkpoint_dir, f)
|
64 |
+
|
65 |
+
|
66 |
+
def _merge_deepspeed_weights(deepspeed_ckpt_path: str, fp32_ckpt_path: str):
|
67 |
+
'''
|
68 |
+
Merges tensors with keys in the DeepSpeed checkpoint but not in the fp32_checkpoint
|
69 |
+
into the fp32 state dict.
|
70 |
+
deepspeed_ckpt_path: Path to the DeepSpeed checkpoint folder.
|
71 |
+
fp32_ckpt_path: Path to the reconstructed
|
72 |
+
'''
|
73 |
+
from pytorch_lightning.utilities.deepspeed import ds_checkpoint_dir
|
74 |
+
|
75 |
+
|
76 |
+
# This first part is based on pytorch_lightning.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict
|
77 |
+
checkpoint_dir = ds_checkpoint_dir(deepspeed_ckpt_path)
|
78 |
+
optim_files = get_optim_files(checkpoint_dir)
|
79 |
+
optim_state = torch.load(optim_files[0], map_location='cpu')
|
80 |
+
zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
|
81 |
+
deepspeed_model_file = get_model_state_file(checkpoint_dir, zero_stage)
|
82 |
+
|
83 |
+
# Start adding all parameters from DeepSpeed ckpt to generated PyTorch Lightning ckpt
|
84 |
+
ds_ckpt = torch.load(deepspeed_model_file, map_location='cpu')
|
85 |
+
ds_sd = ds_ckpt['module']
|
86 |
+
|
87 |
+
fp32_ckpt = torch.load(fp32_ckpt_path, map_location='cpu')
|
88 |
+
fp32_sd = fp32_ckpt['state_dict']
|
89 |
+
|
90 |
+
for k, v in ds_sd.items():
|
91 |
+
try:
|
92 |
+
match = re.match(DS_PARAM_REGEX, k)
|
93 |
+
param_name = match.group(1)
|
94 |
+
except:
|
95 |
+
print(f'Failed to extract parameter from DeepSpeed key {k}')
|
96 |
+
continue
|
97 |
+
|
98 |
+
v = v.to(torch.float32)
|
99 |
+
if param_name not in fp32_sd:
|
100 |
+
print(f'Adding parameter {param_name} from DeepSpeed state_dict to fp32_sd')
|
101 |
+
fp32_sd[param_name] = v
|
102 |
+
else:
|
103 |
+
assert torch.allclose(v, fp32_sd[param_name].to(torch.float32), atol=1e-2)
|
104 |
+
|
105 |
+
return fp32_ckpt
|
106 |
+
|
107 |
+
|
108 |
+
def get_version_and_step(f, i):
|
109 |
+
step = f.split("step=")[-1].split(".")[0]
|
110 |
+
if "-v" in step:
|
111 |
+
[step, version] = step.split("-v")
|
112 |
+
else:
|
113 |
+
step, version = step, 0
|
114 |
+
|
115 |
+
return int(version), int(step), i
|
116 |
+
|
117 |
+
|
118 |
+
def get_latest_ds(name, extra_args=None):
|
119 |
+
if extra_args is None:
|
120 |
+
extra_args = dict()
|
121 |
+
files = os.listdir(f"../checkpoints/{name}")
|
122 |
+
latest = sorted([get_version_and_step(f, i) for i, f in enumerate(files)], reverse=True)[0]
|
123 |
+
selected = files[latest[-1]]
|
124 |
+
# print(f"Selecting file: {selected}")
|
125 |
+
ds_chkpt = join(name, selected)
|
126 |
+
reg_chkpt = join(name + "_fp32", selected)
|
127 |
+
reg_chkpt_path = join("../checkpoints", reg_chkpt)
|
128 |
+
if not os.path.exists(reg_chkpt_path):
|
129 |
+
os.makedirs(os.path.dirname(reg_chkpt_path), exist_ok=True)
|
130 |
+
print(f"Checkpoint {reg_chkpt} does not exist, converting from deepspeed")
|
131 |
+
convert_deepspeed_checkpoint(join("../checkpoints", ds_chkpt), reg_chkpt_path)
|
132 |
+
return dict(
|
133 |
+
chkpt_name=reg_chkpt,
|
134 |
+
extra_args=extra_args)
|
135 |
+
|
136 |
+
|
137 |
+
def get_all_models_in_dir(name, checkpoint_dir, extra_args=None):
|
138 |
+
ret = {}
|
139 |
+
for model_dir in os.listdir(join(checkpoint_dir, name)):
|
140 |
+
full_name = f"{name}/{model_dir}/train"
|
141 |
+
# print(f'"{full_name}",')
|
142 |
+
ret[full_name] = get_latest(full_name, checkpoint_dir, extra_args)
|
143 |
+
return ret
|
144 |
+
|
145 |
+
|
146 |
+
def saved_model_dict(checkpoint_dir):
|
147 |
+
model_info = {
|
148 |
+
|
149 |
+
**get_all_models_in_dir(
|
150 |
+
"9-5-23-mixed",
|
151 |
+
checkpoint_dir,
|
152 |
+
extra_args=dict(
|
153 |
+
mixup_weight=0.0,
|
154 |
+
sim_use_cls=False,
|
155 |
+
audio_pool_width=1,
|
156 |
+
memory_buffer_size=0,
|
157 |
+
loss_leak=0.0)
|
158 |
+
),
|
159 |
+
|
160 |
+
**get_all_models_in_dir(
|
161 |
+
"1-23-24-rebuttal-heads",
|
162 |
+
checkpoint_dir,
|
163 |
+
extra_args=dict(
|
164 |
+
loss_leak=0.0)
|
165 |
+
),
|
166 |
+
|
167 |
+
**get_all_models_in_dir(
|
168 |
+
"11-8-23",
|
169 |
+
checkpoint_dir,
|
170 |
+
extra_args=dict(loss_leak=0.0)),
|
171 |
+
|
172 |
+
**get_all_models_in_dir(
|
173 |
+
"10-30-23-3",
|
174 |
+
checkpoint_dir,
|
175 |
+
extra_args=dict(loss_leak=0.0)),
|
176 |
+
|
177 |
+
"davenet": dict(
|
178 |
+
chkpt_name=None,
|
179 |
+
extra_args=dict(
|
180 |
+
audio_blur=1,
|
181 |
+
image_model_type="davenet",
|
182 |
+
image_aligner_type=None,
|
183 |
+
audio_model_type="davenet",
|
184 |
+
audio_aligner_type=None,
|
185 |
+
audio_input="davenet_spec",
|
186 |
+
use_cached_embs=False,
|
187 |
+
dropout=False,
|
188 |
+
sim_agg_heads=1,
|
189 |
+
nonneg_sim=False,
|
190 |
+
audio_lora=False,
|
191 |
+
image_lora=False,
|
192 |
+
norm_vectors=False,
|
193 |
+
),
|
194 |
+
data_args=dict(
|
195 |
+
use_cached_embs=False,
|
196 |
+
use_davenet_spec=True,
|
197 |
+
override_target_length=20,
|
198 |
+
audio_model_type="davenet",
|
199 |
+
),
|
200 |
+
),
|
201 |
+
|
202 |
+
"cavmae": dict(
|
203 |
+
chkpt_name=None,
|
204 |
+
extra_args=dict(
|
205 |
+
audio_blur=1,
|
206 |
+
image_model_type="cavmae",
|
207 |
+
image_aligner_type=None,
|
208 |
+
audio_model_type="cavmae",
|
209 |
+
audio_aligner_type=None,
|
210 |
+
audio_input="spec",
|
211 |
+
use_cached_embs=False,
|
212 |
+
sim_agg_heads=1,
|
213 |
+
dropout=False,
|
214 |
+
nonneg_sim=False,
|
215 |
+
audio_lora=False,
|
216 |
+
image_lora=False,
|
217 |
+
norm_vectors=False,
|
218 |
+
learn_audio_cls=False,
|
219 |
+
sim_agg_type="cavmae",
|
220 |
+
),
|
221 |
+
data_args=dict(
|
222 |
+
use_cached_embs=False,
|
223 |
+
use_davenet_spec=True,
|
224 |
+
audio_model_type="cavmae",
|
225 |
+
override_target_length=10,
|
226 |
+
),
|
227 |
+
),
|
228 |
+
|
229 |
+
"imagebind": dict(
|
230 |
+
chkpt_name=None,
|
231 |
+
extra_args=dict(
|
232 |
+
audio_blur=1,
|
233 |
+
image_model_type="imagebind",
|
234 |
+
image_aligner_type=None,
|
235 |
+
audio_model_type="imagebind",
|
236 |
+
audio_aligner_type=None,
|
237 |
+
audio_input="spec",
|
238 |
+
use_cached_embs=False,
|
239 |
+
sim_agg_heads=1,
|
240 |
+
dropout=False,
|
241 |
+
nonneg_sim=False,
|
242 |
+
audio_lora=False,
|
243 |
+
image_lora=False,
|
244 |
+
norm_vectors=False,
|
245 |
+
learn_audio_cls=False,
|
246 |
+
sim_agg_type="imagebind",
|
247 |
+
),
|
248 |
+
data_args=dict(
|
249 |
+
use_cached_embs=False,
|
250 |
+
use_davenet_spec=True,
|
251 |
+
audio_model_type="imagebind",
|
252 |
+
override_target_length=10,
|
253 |
+
),
|
254 |
+
),
|
255 |
+
|
256 |
+
}
|
257 |
+
|
258 |
+
model_info["denseav_language"] = model_info["10-30-23-3/places_base/train"]
|
259 |
+
model_info["denseav_sound"] = model_info["11-8-23/hubert_1h_asf_cls_full_image_train_small_lr/train"]
|
260 |
+
model_info["denseav_2head"] = model_info["1-23-24-rebuttal-heads/mixed-2h/train"]
|
261 |
+
|
262 |
+
return model_info
|
DenseAV/denseav/shared.py
ADDED
@@ -0,0 +1,739 @@
1 |
+
import random
|
2 |
+
from collections import defaultdict, deque
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
import math
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.distributed as dist
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torchaudio
|
12 |
+
import torchvision.transforms as T
|
13 |
+
from PIL import Image
|
14 |
+
from torch.utils.data import Dataset
|
15 |
+
from torchaudio.functional import resample
|
16 |
+
|
17 |
+
|
18 |
+
class UnNormalize(object):
|
19 |
+
def __init__(self, mean, std):
|
20 |
+
self.mean = mean
|
21 |
+
self.std = std
|
22 |
+
|
23 |
+
def __call__(self, image):
|
24 |
+
image2 = torch.clone(image)
|
25 |
+
for t, m, s in zip(image2, self.mean, self.std):
|
26 |
+
t.mul_(s).add_(m)
|
27 |
+
return image2
|
28 |
+
|
29 |
+
|
30 |
+
class SliceDataset(Dataset):
|
31 |
+
|
32 |
+
def __init__(self, ds, start, end):
|
33 |
+
self.ds = ds
|
34 |
+
self.start = start
|
35 |
+
self.end = end
|
36 |
+
|
37 |
+
def __len__(self):
|
38 |
+
return self.end - self.start
|
39 |
+
|
40 |
+
def __getitem__(self, item):
|
41 |
+
return self.ds[item + self.start]
|
42 |
+
|
43 |
+
|
44 |
+
class SubsetDataset(Dataset):
|
45 |
+
|
46 |
+
def __init__(self, ds, subset):
|
47 |
+
self.ds = ds
|
48 |
+
self.subset = subset
|
49 |
+
|
50 |
+
def __len__(self):
|
51 |
+
return len(self.subset)
|
52 |
+
|
53 |
+
def __getitem__(self, item):
|
54 |
+
return self.ds[self.subset[item]]
|
55 |
+
|
56 |
+
|
57 |
+
norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
58 |
+
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
59 |
+
|
60 |
+
|
61 |
+
def crop_to_divisor(x, patch_size):
|
62 |
+
if len(x.shape) == 3:
|
63 |
+
C, H, W = x.shape
|
64 |
+
return x[:, :(patch_size * (H // patch_size)), :(patch_size * (W // patch_size))]
|
65 |
+
elif len(x.shape) == 4:
|
66 |
+
B, C, H, W = x.shape
|
67 |
+
return x[:, :, :(patch_size * (H // patch_size)), :(patch_size * (W // patch_size))]
|
68 |
+
else:
|
69 |
+
raise ValueError("x should have 3 or 4 dimensions")
|
70 |
+
|
71 |
+
|
72 |
+
def _remove_axes(ax):
|
73 |
+
ax.xaxis.set_major_formatter(plt.NullFormatter())
|
74 |
+
ax.yaxis.set_major_formatter(plt.NullFormatter())
|
75 |
+
ax.set_xticks([])
|
76 |
+
ax.set_yticks([])
|
77 |
+
|
78 |
+
|
79 |
+
def remove_axes(axes):
|
80 |
+
if len(axes.shape) == 2:
|
81 |
+
for ax1 in axes:
|
82 |
+
for ax in ax1:
|
83 |
+
_remove_axes(ax)
|
84 |
+
else:
|
85 |
+
for ax in axes:
|
86 |
+
_remove_axes(ax)
|
87 |
+
|
88 |
+
|
89 |
+
def get_image_featurizer(name, token_type="key", **kwargs):
|
90 |
+
name = name.lower()
|
91 |
+
|
92 |
+
if name == "vit":
|
93 |
+
from denseav.featurizers.DINO import DINOFeaturizer
|
94 |
+
patch_size = 16
|
95 |
+
model = DINOFeaturizer("vit_small_patch16_224", patch_size, token_type)
|
96 |
+
dim = 384
|
97 |
+
elif name == "dino16":
|
98 |
+
from denseav.featurizers.DINO import DINOFeaturizer
|
99 |
+
patch_size = 16
|
100 |
+
model = DINOFeaturizer("dino_vits16", patch_size, token_type)
|
101 |
+
dim = 384
|
102 |
+
elif name == "dino8":
|
103 |
+
from denseav.featurizers.DINO import DINOFeaturizer
|
104 |
+
patch_size = 8
|
105 |
+
model = DINOFeaturizer("dino_vits8", patch_size, token_type)
|
106 |
+
dim = 384
|
107 |
+
elif name == "clip":
|
108 |
+
from denseav.featurizers.CLIP import CLIPFeaturizer
|
109 |
+
patch_size = 16
|
110 |
+
model = CLIPFeaturizer()
|
111 |
+
dim = 512
|
112 |
+
elif name == "cavmae":
|
113 |
+
from denseav.featurizers.CAVMAE import CAVMAEImageFeaturizer
|
114 |
+
model = CAVMAEImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
|
115 |
+
dim = 768
|
116 |
+
patch_size = 16
|
117 |
+
elif name == "fnac":
|
118 |
+
from denseav.featurizers.FNACAVL import FNACImageFeaturizer
|
119 |
+
model = FNACImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
|
120 |
+
dim = 512
|
121 |
+
patch_size = 16
|
122 |
+
elif name == "imagebind":
|
123 |
+
from denseav.featurizers.ImageBind import ImageBindImageFeaturizer
|
124 |
+
model = ImageBindImageFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
|
125 |
+
dim = 1024
|
126 |
+
patch_size = 16
|
127 |
+
elif name == "resnet50":
|
128 |
+
from torchvision import models
|
129 |
+
model = models.resnet50(pretrained=True)
|
130 |
+
model = torch.nn.Sequential(*list(model.children())[:-2])
|
131 |
+
patch_size = 1
|
132 |
+
dim = 2048
|
133 |
+
elif name == "davenet":
|
134 |
+
from denseav.featurizers.DAVENet import DavenetImageFeaturizer
|
135 |
+
model = DavenetImageFeaturizer()
|
136 |
+
patch_size = 1
|
137 |
+
dim = 1024
|
138 |
+
elif name == "dinov2":
|
139 |
+
from denseav.featurizers.DINOv2 import DINOv2Featurizer
|
140 |
+
model = DINOv2Featurizer()
|
141 |
+
patch_size = 14
|
142 |
+
dim = 768
|
143 |
+
else:
|
144 |
+
raise ValueError("unknown model: {}".format(name))
|
145 |
+
return model, patch_size, dim
|
146 |
+
|
147 |
+
|
148 |
+
def get_audio_featurizer(name, **kwargs):
|
149 |
+
if name == "davenet":
|
150 |
+
from denseav.featurizers.DAVENet import DavenetAudioFeaturizer
|
151 |
+
model = DavenetAudioFeaturizer()
|
152 |
+
dim = 1024
|
153 |
+
elif name == "dino8":
|
154 |
+
model, _, dim = get_image_featurizer("dino8")
|
155 |
+
elif name == "hubert":
|
156 |
+
from denseav.featurizers.Hubert import Hubert
|
157 |
+
model = Hubert()
|
158 |
+
dim = 1024
|
159 |
+
elif name == "cavmae":
|
160 |
+
from denseav.featurizers.CAVMAE import CAVMAEAudioFeaturizer
|
161 |
+
model = CAVMAEAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
|
162 |
+
dim = 768
|
163 |
+
elif name == "imagebind":
|
164 |
+
from denseav.featurizers.ImageBind import ImageBindAudioFeaturizer
|
165 |
+
model = ImageBindAudioFeaturizer(kwargs["output_root"], model=kwargs.get("model"))
|
166 |
+
dim = 1024
|
167 |
+
elif name == "audiomae":
|
168 |
+
from denseav.featurizers.AudioMAE import AudioMAE
|
169 |
+
model = AudioMAE(kwargs["output_root"], False)
|
170 |
+
dim = 768
|
171 |
+
elif name == "audiomae-finetuned":
|
172 |
+
from denseav.featurizers.AudioMAE import AudioMAE
|
173 |
+
model = AudioMAE(kwargs["output_root"], True)
|
174 |
+
dim = 768
|
175 |
+
else:
|
176 |
+
raise ValueError("Unknown audio model type")
|
177 |
+
|
178 |
+
return model, dim
|
179 |
+
|
180 |
+
|
181 |
+
def load_img(image_path, transform):
|
182 |
+
return transform(Image.open(image_path)).unsqueeze(0)
|
183 |
+
|
184 |
+
|
185 |
+
def pytorch_to_pil(tensor):
|
186 |
+
return Image.fromarray((unnorm(tensor).permute(0, 2, 3, 1).cpu() * 255)
|
187 |
+
.clamp(0, 255).to(torch.uint8).detach().numpy()[0])
|
188 |
+
|
189 |
+
|
190 |
+
def _get_random_window(waveform, mask, min_size, max_size):
|
191 |
+
effective_size = mask.sum().to(torch.int64)
|
192 |
+
if effective_size <= min_size:
|
193 |
+
return waveform, mask
|
194 |
+
else:
|
195 |
+
window_size = min(torch.randint(low=min_size, high=min(effective_size, max_size), size=()), waveform.shape[0])
|
196 |
+
if window_size == waveform.shape[0]:
|
197 |
+
window_start = 0
|
198 |
+
else:
|
199 |
+
window_start = torch.randint(low=0, high=effective_size - window_size, size=())
|
200 |
+
|
201 |
+
new_waveform = torch.zeros_like(waveform)
|
202 |
+
new_mask = torch.zeros_like(mask)
|
203 |
+
new_waveform[window_start:window_start + window_size] = waveform[window_start:window_start + window_size]
|
204 |
+
new_mask[window_start:window_start + window_size] = mask[window_start:window_start + window_size]
|
205 |
+
return new_waveform, new_mask
|
206 |
+
|
207 |
+
|
208 |
+
def _splice_clips(clip1, clip2, loc, easing_size):
|
209 |
+
assert loc >= 0 and loc < len(clip1), "Invalid location"
|
210 |
+
assert easing_size > 0 and easing_size <= len(clip2), "Invalid easing size"
|
211 |
+
|
212 |
+
try:
|
213 |
+
assert loc + clip2.shape[0] < clip1.shape[0]
|
214 |
+
except Exception as e:
|
215 |
+
print(loc, clip2.shape[0], clip1.shape[0])
|
216 |
+
raise e
|
217 |
+
|
218 |
+
# Split clip1 into three parts: before splice, easing region, after splice
|
219 |
+
before_splice = clip1[:loc]
|
220 |
+
after_splice = clip1[loc + clip2.shape[0]:]
|
221 |
+
|
222 |
+
# Compute the fading weights for the easing region
|
223 |
+
# fade_in_weights = torch.cos(torch.linspace(1, 0, easing_size, device=clip1.device))
|
224 |
+
fade_in_weights = 0.5 * (1 + torch.cos(math.pi * torch.linspace(0, 1, easing_size)))
|
225 |
+
fade_out_weights = 1 - fade_in_weights
|
226 |
+
|
227 |
+
clip1_ease = torch.cat([
|
228 |
+
fade_in_weights,
|
229 |
+
torch.zeros(clip2.shape[0] - easing_size * 2),
|
230 |
+
fade_out_weights,
|
231 |
+
])
|
232 |
+
|
233 |
+
mask = torch.cat([torch.ones(loc), clip1_ease, torch.ones(clip1.shape[0] - (loc + clip2.shape[0]))])
|
234 |
+
|
235 |
+
# Apply fading weights to clip1 and clip2 within the easing region
|
236 |
+
splice = clip1_ease * clip1[loc:loc + clip2.shape[0]] + (1 - clip1_ease) * clip2
|
237 |
+
|
238 |
+
# Concatenate all parts back together
|
239 |
+
spliced_clip = torch.cat((before_splice, splice, after_splice))
|
240 |
+
|
241 |
+
return spliced_clip, mask
|
242 |
+
|
243 |
+
|
244 |
+
def _generate_random_subset(waveform, low, high):
|
245 |
+
length = len(waveform)
|
246 |
+
|
247 |
+
# If waveform is smaller than low or has zero length, return unmodified
|
248 |
+
if length < low or length == 0:
|
249 |
+
return waveform
|
250 |
+
|
251 |
+
# Generate random start index within valid range
|
252 |
+
start = random.randint(0, length - low)
|
253 |
+
|
254 |
+
# Generate random subset size within valid range
|
255 |
+
subset_size = random.randint(low, min(high, length - start))
|
256 |
+
|
257 |
+
# Extract the random subset from the waveform
|
258 |
+
subset = waveform[start: start + subset_size]
|
259 |
+
|
260 |
+
return subset
|
261 |
+
|
262 |
+
|
263 |
+
def level_audio(waveform):
|
264 |
+
waveform -= waveform.mean()
|
265 |
+
waveform /= waveform.abs().max().clamp_min(.0001)
|
266 |
+
return waveform
|
267 |
+
|
268 |
+
|
269 |
+
def prep_waveform(waveform,
|
270 |
+
obs_sr,
|
271 |
+
target_length,
|
272 |
+
spec_mel_bins,
|
273 |
+
spec_mean,
|
274 |
+
spec_std,
|
275 |
+
sample_rate,
|
276 |
+
return_spec,
|
277 |
+
random_clip,
|
278 |
+
extra_audio_masking,
|
279 |
+
neg_waveform,
|
280 |
+
neg_obs_sr,
|
281 |
+
audio_level,
|
282 |
+
audio_aug,
|
283 |
+
):
|
284 |
+
if obs_sr != sample_rate:
|
285 |
+
waveform = resample(waveform, obs_sr, sample_rate)
|
286 |
+
if audio_level:
|
287 |
+
waveform = level_audio(waveform)
|
288 |
+
|
289 |
+
if neg_obs_sr is not None and neg_obs_sr != sample_rate:
|
290 |
+
neg_waveform = resample(neg_waveform, neg_obs_sr, sample_rate)
|
291 |
+
if audio_level:
|
292 |
+
neg_waveform = level_audio(neg_waveform)
|
293 |
+
|
294 |
+
if neg_obs_sr is not None: # and random.random() > .5:
|
295 |
+
neg_waveform_clip = _generate_random_subset(neg_waveform, sample_rate, sample_rate * 4)
|
296 |
+
if waveform.shape[0] - neg_waveform_clip.shape[0] > 0:
|
297 |
+
start = random.randint(0, waveform.shape[0] - neg_waveform_clip.shape[0] - 1)
|
298 |
+
easing = max(int(neg_waveform_clip.shape[0] * 1 / 4), sample_rate // 2)
|
299 |
+
easing = min(int(neg_waveform_clip.shape[0] * 1 / 2), easing)
|
300 |
+
waveform, pos_mask = _splice_clips(waveform, neg_waveform_clip, start, easing_size=easing)
|
301 |
+
else:
|
302 |
+
waveform, pos_mask = waveform, torch.ones_like(waveform)
|
303 |
+
else:
|
304 |
+
waveform, pos_mask = waveform, torch.ones_like(waveform)
|
305 |
+
|
306 |
+
mask = torch.ones_like(waveform)
|
307 |
+
original_length = waveform.shape[0]
|
308 |
+
|
309 |
+
if target_length == 10:
|
310 |
+
target_samples = 164200 # Result is 1024 after spec
|
311 |
+
else:
|
312 |
+
target_samples = int(target_length * sample_rate)
|
313 |
+
|
314 |
+
padding = target_samples - original_length
|
315 |
+
|
316 |
+
if padding > 0:
|
317 |
+
p = torch.nn.ZeroPad2d((0, padding))
|
318 |
+
waveform = p(waveform)
|
319 |
+
mask = p(mask)
|
320 |
+
pos_mask = p(pos_mask)
|
321 |
+
else:
|
322 |
+
if random_clip:
|
323 |
+
start = torch.randint(0, waveform.shape[0] - target_samples, size=())
|
324 |
+
else:
|
325 |
+
start = 0
|
326 |
+
end = start + target_samples
|
327 |
+
waveform = waveform[start:end]
|
328 |
+
mask = mask[start:end]
|
329 |
+
pos_mask = pos_mask[start:end]
|
330 |
+
|
331 |
+
audio_length = min(original_length, target_samples)
|
332 |
+
total_length = target_samples
|
333 |
+
|
334 |
+
if extra_audio_masking:
|
335 |
+
min_size = sample_rate // 2
|
336 |
+
max_size = total_length
|
337 |
+
if original_length > min_size and random.random() > .5:
|
338 |
+
waveform, mask = _get_random_window(waveform, mask, min_size, max_size)
|
339 |
+
|
340 |
+
if audio_aug:
|
341 |
+
import torchaudio_augmentations as AA
|
342 |
+
from torchvision.transforms import RandomApply, Compose
|
343 |
+
|
344 |
+
transform = Compose([
|
345 |
+
RandomApply([AA.PolarityInversion()], p=0.5),
|
346 |
+
RandomApply([AA.Noise(min_snr=0.001, max_snr=0.005)], p=0.2),
|
347 |
+
RandomApply([AA.Gain()], p=0.2),
|
348 |
+
RandomApply([AA.HighLowPass(sample_rate=sample_rate)], p=0.2),
|
349 |
+
RandomApply([AA.PitchShift(n_samples=waveform.shape[-1], sample_rate=sample_rate)], p=0.2),
|
350 |
+
RandomApply([AA.Reverb(sample_rate=sample_rate)], p=0.2)
|
351 |
+
])
|
352 |
+
waveform = transform(waveform.unsqueeze(0)).squeeze(0)
|
353 |
+
|
354 |
+
if return_spec:
|
355 |
+
spectrogram = torchaudio.compliance.kaldi.fbank(
|
356 |
+
waveform.unsqueeze(0) - waveform.mean(),
|
357 |
+
htk_compat=True,
|
358 |
+
sample_frequency=sample_rate,
|
359 |
+
use_energy=False,
|
360 |
+
window_type='hanning',
|
361 |
+
num_mel_bins=spec_mel_bins,
|
362 |
+
dither=0.0,
|
363 |
+
frame_shift=10)
|
364 |
+
|
365 |
+
spectrogram = ((spectrogram - spec_mean) / spec_std).unsqueeze(0)
|
366 |
+
else:
|
367 |
+
spectrogram = None
|
368 |
+
|
369 |
+
if mask.mean() < .04:
|
370 |
+
print(f"Bad entry: {mask.mean()}")
|
371 |
+
|
372 |
+
return waveform, spectrogram, audio_length, total_length, original_length, mask, pos_mask
|
373 |
+
|
374 |
+
|
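A hedged usage sketch for `prep_waveform`: the sample rates, mel-bin count, and spectrogram normalization constants below are placeholders rather than DenseAV's training values, and the function is assumed to be importable from `denseav.shared`.

```python
import torch

waveform = torch.randn(3 * 22050)  # ~3 s of audio observed at 22.05 kHz

out = prep_waveform(
    waveform,
    obs_sr=22050,          # observed sample rate; triggers a resample to 16 kHz
    target_length=10,      # seconds; pads or crops to a fixed-length tensor
    spec_mel_bins=128,
    spec_mean=-4.27,       # placeholder normalization constants
    spec_std=4.57,
    sample_rate=16000,
    return_spec=True,
    random_clip=True,
    extra_audio_masking=False,
    neg_waveform=None,
    neg_obs_sr=None,
    audio_level=False,
    audio_aug=False,
)
waveform, spectrogram, audio_length, total_length, original_length, mask, pos_mask = out
print(spectrogram.shape)   # (1, ~1024, 128) for a 10 s clip at 16 kHz with a 10 ms frame shift
```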
375 |
+
class ToTargetTensor(object):
|
376 |
+
def __call__(self, target):
|
377 |
+
return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)
|
378 |
+
|
379 |
+
|
380 |
+
def show_heatmap(ax,
|
381 |
+
image,
|
382 |
+
heatmap,
|
383 |
+
cmap="bwr",
|
384 |
+
color=False,
|
385 |
+
center=False,
|
386 |
+
show_negative=False,
|
387 |
+
cax=None,
|
388 |
+
vmax=None,
|
389 |
+
vmin=None):
|
390 |
+
frame = []
|
391 |
+
|
392 |
+
if color:
|
393 |
+
frame.append(ax.imshow(image))
|
394 |
+
else:
|
395 |
+
bw = np.dot(np.array(image)[..., :3] / 255, [0.2989, 0.5870, 0.1140])
|
396 |
+
bw = np.ones_like(image) * np.expand_dims(bw, -1)
|
397 |
+
frame.append(ax.imshow(bw))
|
398 |
+
|
399 |
+
if center:
|
400 |
+
heatmap -= heatmap.mean()
|
401 |
+
|
402 |
+
if not show_negative:
|
403 |
+
heatmap = heatmap.clamp_min(0)
|
404 |
+
|
405 |
+
heatmap = F.interpolate(heatmap.unsqueeze(0).unsqueeze(0), (image.shape[0], image.shape[1])) \
|
406 |
+
.squeeze(0).squeeze(0)
|
407 |
+
|
408 |
+
if vmax is None:
|
409 |
+
vmax = np.abs(heatmap).max()
|
410 |
+
if vmin is None:
|
411 |
+
vmin = -vmax
|
412 |
+
|
413 |
+
hm = ax.imshow(heatmap, alpha=.5, cmap=cmap, vmax=vmax, vmin=vmin)
|
414 |
+
if cax is not None:
|
415 |
+
plt.colorbar(hm, cax=cax, orientation='vertical')
|
416 |
+
|
417 |
+
frame.extend([hm])
|
418 |
+
return frame
|
419 |
+
|
420 |
+
|
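A toy call to `show_heatmap` on random data, just to show the expected inputs: an `(H, W, 3)` image array and a coarse score map that gets bilinearly upsampled onto it. The data here is illustrative, not real DenseAV attention.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch

image = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)  # fake RGB frame
heatmap = torch.rand(14, 14)                                   # coarse score map

fig, ax = plt.subplots()
show_heatmap(ax, image, heatmap, cmap="bwr", color=True, center=True)
plt.show()
```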
421 |
+
class TorchPCA(object):
|
422 |
+
|
423 |
+
def __init__(self, n_components):
|
424 |
+
self.n_components = n_components
|
425 |
+
|
426 |
+
def fit(self, X):
|
427 |
+
self.mean_ = X.mean(dim=0)
|
428 |
+
unbiased = X - self.mean_.unsqueeze(0)
|
429 |
+
U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
|
430 |
+
self.components_ = V.T
|
431 |
+
self.singular_values_ = S
|
432 |
+
return self
|
433 |
+
|
434 |
+
def transform(self, X):
|
435 |
+
t0 = X - self.mean_.unsqueeze(0)
|
436 |
+
projected = t0 @ self.components_.T
|
437 |
+
return projected
|
438 |
+
|
439 |
+
|
440 |
+
def pca(image_feats_list, dim=3, fit_pca=None):
|
441 |
+
device = image_feats_list[0].device
|
442 |
+
|
443 |
+
def flatten(tensor, target_size=None):
|
444 |
+
if target_size is not None and fit_pca is None:
|
445 |
+
tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
|
446 |
+
B, C, H, W = tensor.shape
|
447 |
+
return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()
|
448 |
+
|
449 |
+
if len(image_feats_list) > 1 and fit_pca is None:
|
450 |
+
target_size = image_feats_list[0].shape[2]
|
451 |
+
else:
|
452 |
+
target_size = None
|
453 |
+
|
454 |
+
flattened_feats = []
|
455 |
+
for feats in image_feats_list:
|
456 |
+
flattened_feats.append(flatten(feats, target_size))
|
457 |
+
x = torch.cat(flattened_feats, dim=0)
|
458 |
+
|
459 |
+
if fit_pca is None:
|
460 |
+
# fit_pca = PCA(n_components=dim, svd_solver='arpack').fit(np.nan_to_num(x.detach().numpy()))
|
461 |
+
fit_pca = TorchPCA(n_components=dim).fit(x)
|
462 |
+
|
463 |
+
reduced_feats = []
|
464 |
+
for feats in image_feats_list:
|
465 |
+
# x_red = torch.from_numpy(fit_pca.transform(flatten(feats)))
|
466 |
+
x_red = fit_pca.transform(flatten(feats))
|
467 |
+
x_red -= x_red.min(dim=0, keepdim=True).values
|
468 |
+
x_red /= x_red.max(dim=0, keepdim=True).values
|
469 |
+
B, C, H, W = feats.shape
|
470 |
+
reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))
|
471 |
+
|
472 |
+
return reduced_feats, fit_pca
|
473 |
+
|
474 |
+
|
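A quick sketch of `TorchPCA`/`pca` in use: two batches of features are projected into a shared 3-dimensional basis so they can be rendered as RGB, and the fitted basis is reused so colors stay comparable. Shapes are arbitrary stand-ins.

```python
import torch

feats_a = torch.randn(2, 64, 14, 14)   # (B, C, H, W) feature maps
feats_b = torch.randn(2, 64, 14, 14)

[rgb_a, rgb_b], fit = pca([feats_a, feats_b], dim=3)
print(rgb_a.shape)                      # torch.Size([2, 3, 14, 14]), scaled to [0, 1]

# Reusing the fitted PCA keeps the color mapping consistent on new features.
[rgb_new], _ = pca([torch.randn(2, 64, 14, 14)], dim=3, fit_pca=fit)
```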
475 |
+
def merge_col(fig, axes, col):
|
476 |
+
gs = axes[0, col].get_gridspec()
|
477 |
+
for ax in axes[:, col]:
|
478 |
+
ax.remove()
|
479 |
+
return fig.add_subplot(gs[:, col])
|
480 |
+
|
481 |
+
|
482 |
+
def visualize_av_features(
|
483 |
+
audio,
|
484 |
+
video,
|
485 |
+
feat_a,
|
486 |
+
feat_v,
|
487 |
+
att_a,
|
488 |
+
n_frames,
|
489 |
+
norm_before_pca=True,
|
490 |
+
axes=None,
|
491 |
+
fig=None,
|
492 |
+
modify_fig=True,
|
493 |
+
video_time=0,
|
494 |
+
fit_pca=None
|
495 |
+
):
|
496 |
+
assert (len(audio.shape) == 3) # C, F, T
|
497 |
+
assert (len(video.shape) == 4) # T, C, H, W
|
498 |
+
assert (len(feat_a.shape) == 2) # C, T
|
499 |
+
assert (len(feat_v.shape) == 4) # T, C, H, W
|
500 |
+
assert (len(att_a.shape) == 2) # F, T
|
501 |
+
|
502 |
+
ac, af, at = audio.shape
|
503 |
+
fac, fat = feat_a.shape
|
504 |
+
|
505 |
+
if modify_fig:
|
506 |
+
if axes is None:
|
507 |
+
fig, axes = plt.subplots(3, 3, figsize=(5 * 3, 5))
|
508 |
+
fig.tight_layout()
|
509 |
+
|
510 |
+
bigax1 = merge_col(fig, axes, 0)
|
511 |
+
bigax2 = merge_col(fig, axes, 1)
|
512 |
+
_remove_axes(bigax1)
|
513 |
+
_remove_axes(bigax2)
|
514 |
+
remove_axes(axes[:, 2])
|
515 |
+
else:
|
516 |
+
bigax1 = fig.axes[-2]
|
517 |
+
bigax2 = fig.axes[-1]
|
518 |
+
|
519 |
+
frame_v = unnorm(video).permute(0, 2, 3, 1).detach().cpu()
|
520 |
+
frame_v -= frame_v.min()
|
521 |
+
frame_v /= frame_v.max()
|
522 |
+
|
523 |
+
frame_a = audio.detach().cpu()
|
524 |
+
frame_a -= frame_a.min()
|
525 |
+
frame_a /= frame_a.max()
|
526 |
+
|
527 |
+
if norm_before_pca:
|
528 |
+
[red_feat_v], fit_pca = pca([F.normalize(feat_v, dim=1)], fit_pca=fit_pca)
|
529 |
+
[red_feat_a], _ = pca([F.normalize(feat_a.unsqueeze(0).unsqueeze(-1), dim=1)], fit_pca=fit_pca)
|
530 |
+
else:
|
531 |
+
[red_feat_v], fit_pca = pca([feat_v], fit_pca=fit_pca)
|
532 |
+
[red_feat_a], _ = pca([feat_a.unsqueeze(0).unsqueeze(-1)], fit_pca=fit_pca)
|
533 |
+
|
534 |
+
red_feat_v = red_feat_v.permute(0, 2, 3, 1).detach().cpu()
|
535 |
+
red_feat_a = red_feat_a.permute(0, 2, 3, 1)[0].detach().cpu()
|
536 |
+
|
537 |
+
if red_feat_a.shape[0] == 1:
|
538 |
+
new_height = int((frame_a.shape[0] / frame_a.shape[1]) * red_feat_a.shape[1])
|
539 |
+
red_feat_a = torch.broadcast_to(
|
540 |
+
red_feat_a, (new_height, red_feat_a.shape[1], red_feat_a.shape[2]))
|
541 |
+
plt_att_a = torch.broadcast_to(att_a, (new_height, att_a.shape[1]))
|
542 |
+
else:
|
543 |
+
plt_att_a = att_a
|
544 |
+
|
545 |
+
frac_signal = n_frames / fat
|
546 |
+
n_at = int(at * frac_signal)
|
547 |
+
|
548 |
+
return [bigax1.imshow(frame_v[video_time]),
|
549 |
+
bigax2.imshow(red_feat_v[video_time]),
|
550 |
+
axes[0, 2].imshow(frame_a[:, :n_at]),
|
551 |
+
axes[0, 2].set_title("Spectrogram"),
|
552 |
+
axes[1, 2].imshow(red_feat_a[:, :n_frames]),
|
553 |
+
axes[1, 2].set_title("Audio Features"),
|
554 |
+
axes[2, 2].imshow(plt_att_a[:, :n_frames], vmin=0),
|
555 |
+
axes[2, 2].set_title("Audio Attention")], fig, fit_pca
|
556 |
+
|
557 |
+
|
558 |
+
def create_label_tensor(labels, starts, ends, max_time, n_steps):
|
559 |
+
assert isinstance(starts, torch.Tensor)
|
560 |
+
assert isinstance(ends, torch.Tensor)
|
561 |
+
|
562 |
+
ends[ends < 0] = max_time
|
563 |
+
fps = n_steps / max_time
|
564 |
+
times = (torch.arange(0, n_steps, device=labels.device, dtype=torch.float32) + .5) / fps
|
565 |
+
after_start = starts.unsqueeze(1) <= times.unsqueeze(0)
|
566 |
+
before_end = ends.unsqueeze(1) >= times.unsqueeze(0)
|
567 |
+
# Find when you are inside of a word
|
568 |
+
in_word = (after_start * before_end)
|
569 |
+
# Find which word you are inside of
|
570 |
+
word_to_use = in_word.to(torch.float32).argmax(0)
|
571 |
+
# Get the label for that word, or mask out the label if in no word
|
572 |
+
final_labels = labels[word_to_use] * in_word.any(0).reshape(-1, 1, 1)
|
573 |
+
return final_labels
|
574 |
+
|
575 |
+
|
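A small sketch of `create_label_tensor`, assuming (as the broadcasting above implies) that `labels` holds one spatial label map per annotated word with shape `(N, H, W)`; the timings and map size are invented.

```python
import torch

labels = torch.randint(0, 150, (3, 7, 7)).float()   # one (H, W) label map per word
starts = torch.tensor([0.0, 1.0, 2.5])              # word start times in seconds
ends = torch.tensor([0.8, 2.0, -1.0])               # -1 means "runs until max_time"

per_step = create_label_tensor(labels, starts, ends, max_time=4.0, n_steps=8)
print(per_step.shape)                               # torch.Size([8, 7, 7]); zeros where no word is active
```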
576 |
+
def generate_subset(n, batch, seed=0):
|
577 |
+
np.random.seed(seed)
|
578 |
+
return np.random.permutation(n)[:batch]
|
579 |
+
|
580 |
+
|
581 |
+
def channel_blur(t, window=5, std_dev=1):
|
582 |
+
tb, tc, th, tw = t.shape
|
583 |
+
x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
|
584 |
+
k = torch.exp((-x ** 2 / (2 * std_dev ** 2)))
|
585 |
+
k = k / k.sum()
|
586 |
+
pad = window // 2
|
587 |
+
t_pad = F.pad(t, [0, 0, 0, 0, pad, pad], mode="replicate")
|
588 |
+
tpb, tpc, tph, tpw = t_pad.shape
|
589 |
+
flattened_t = t_pad.permute(0, 2, 3, 1).reshape(tpb * tph * tpw, 1, -1)
|
590 |
+
return F.conv1d(flattened_t, k.reshape(1, 1, window)).reshape(tpb, tph, tpw, tc).permute(0, 3, 1, 2)
|
591 |
+
|
592 |
+
|
593 |
+
def time_blur(t, window=5, std_dev=1):
|
594 |
+
tb, tc, tt = t.shape
|
595 |
+
with torch.no_grad():
|
596 |
+
x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
|
597 |
+
k = torch.exp((-x ** 2 / (2 * std_dev ** 2)))
|
598 |
+
k = k / k.sum()
|
599 |
+
k = k.reshape(1, 1, window).detach()
|
600 |
+
pad = window // 2
|
601 |
+
t_pad = F.pad(t, [pad, pad], mode="replicate")
|
602 |
+
return F.conv1d(t_pad.reshape(tb * tc, 1, -1), k).reshape(tb, tc, tt)
|
603 |
+
|
604 |
+
|
605 |
+
def create_model_from_cfg(clazz, cfg, extra_args):
|
606 |
+
import inspect
|
607 |
+
expected_args = inspect.getfullargspec(clazz.__init__).args[1:]
|
608 |
+
new_args = {k: v for k, v in {**cfg, **extra_args}.items() if k in expected_args}
|
609 |
+
return clazz(**new_args)
|
610 |
+
|
611 |
+
|
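`create_model_from_cfg` filters a merged config down to the keyword arguments the target class actually accepts; a toy class and config (both invented) make the behavior concrete.

```python
import torch


class ToyModel(torch.nn.Module):
    def __init__(self, code_dim, lr):
        super().__init__()
        self.code_dim, self.lr = code_dim, lr


cfg = {"code_dim": 512, "lr": 3e-4, "unused_option": True}   # extra keys are dropped
model = create_model_from_cfg(ToyModel, cfg, extra_args={"lr": 1e-4})
print(model.code_dim, model.lr)                              # 512 0.0001 (extra_args win)
```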
612 |
+
def load_trained_model(chkpt_dir, extra_args, strict=True):
|
613 |
+
from denseav.train import LitAVAligner
|
614 |
+
model = LitAVAligner.load_from_checkpoint(chkpt_dir, **extra_args, strict=strict).cuda()
|
615 |
+
return model
|
616 |
+
|
617 |
+
|
618 |
+
def flatten(l):
|
619 |
+
return [item for sublist in l for item in sublist]
|
620 |
+
|
621 |
+
|
622 |
+
def flatten_preds(preds):
|
623 |
+
results = {}
|
624 |
+
for k in preds[0].keys():
|
625 |
+
if k == "caption_labels":
|
626 |
+
continue
|
627 |
+
if isinstance(preds[0][k], torch.Tensor):
|
628 |
+
results[k] = torch.cat([p[k] for p in preds], dim=0)
|
629 |
+
if "caption" in preds[0]:
|
630 |
+
results["caption"] = flatten([p["caption"] for p in preds])
|
631 |
+
|
632 |
+
if "metadata" in preds[0]:
|
633 |
+
results["frame_files"] = flatten([list(p["metadata"]["frame_files"][0]) for p in preds])
|
634 |
+
results["audio_file"] = flatten([list(p["metadata"]["audio_file"]) for p in preds])
|
635 |
+
results["id"] = flatten([list(p["metadata"]["id"]) for p in preds])
|
636 |
+
results["index"] = torch.tensor(flatten([list(p["metadata"]["index"]) for p in preds]))
|
637 |
+
|
638 |
+
return results
|
639 |
+
|
640 |
+
|
641 |
+
def batch(iterable, n=1):
|
642 |
+
l = len(iterable)
|
643 |
+
for ndx in range(0, l, n):
|
644 |
+
yield iterable[ndx:min(ndx + n, l)]
|
645 |
+
|
646 |
+
|
647 |
+
class GatherLayer(torch.autograd.Function):
|
648 |
+
"""Gather tensors from all process, supporting backward propagation."""
|
649 |
+
|
650 |
+
@staticmethod
|
651 |
+
def jvp(ctx: Any, *grad_inputs: Any) -> Any:
|
652 |
+
pass
|
653 |
+
|
654 |
+
@staticmethod
|
655 |
+
def forward(ctx, inputs):
|
656 |
+
ctx.save_for_backward(inputs)
|
657 |
+
output = [torch.zeros_like(inputs) for _ in range(dist.get_world_size())]
|
658 |
+
dist.all_gather(output, inputs)
|
659 |
+
return tuple(output)
|
660 |
+
|
661 |
+
@staticmethod
|
662 |
+
def backward(ctx, *grads):
|
663 |
+
(inputs,) = ctx.saved_tensors
|
664 |
+
grad_out = torch.zeros_like(inputs)
|
665 |
+
grad_out[:] = grads[dist.get_rank()]
|
666 |
+
return grad_out
|
667 |
+
|
668 |
+
|
669 |
+
class RollingAvg:
|
670 |
+
|
671 |
+
def __init__(self, length, nonzero=False):
|
672 |
+
self.length = length
|
673 |
+
self.nonzero = nonzero
|
674 |
+
self.metrics = defaultdict(lambda: deque(maxlen=self.length))
|
675 |
+
|
676 |
+
def add(self, name, metric):
|
677 |
+
if self.nonzero and metric == 0:
|
678 |
+
return
|
679 |
+
if isinstance(metric, torch.Tensor):
|
680 |
+
metric = metric.detach()
|
681 |
+
|
682 |
+
self.metrics[name].append(metric)
|
683 |
+
|
684 |
+
def get(self, name):
|
685 |
+
with torch.no_grad():
|
686 |
+
return torch.tensor(list(self.metrics[name])).mean()
|
687 |
+
|
688 |
+
def get_all(self):
|
689 |
+
return {k: self.get(k) for k in self.metrics.keys()}
|
690 |
+
|
691 |
+
def add_all(self, values):
|
692 |
+
for k, v in values.items():
|
693 |
+
self.add(k, v)
|
694 |
+
|
695 |
+
def logall(self, log_func):
|
696 |
+
for k in self.metrics.keys():
|
697 |
+
log_func(k, self.get(k))
|
698 |
+
|
699 |
+
|
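`RollingAvg` keeps a bounded deque per metric name and reports the mean of the most recent `length` values; a toy loop shows the intended pattern.

```python
import torch

avg = RollingAvg(length=50)
for step in range(200):
    avg.add("loss/total", torch.rand(()))
    avg.add("acc/1", float(step % 2))

print(avg.get("loss/total"))                          # mean of the last 50 values only
avg.logall(lambda name, value: print(name, float(value)))
```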
700 |
+
def gaussian_kernel(k, sigma):
|
701 |
+
kernel = torch.tensor([math.exp(-0.5 * (x - (k // 2)) ** 2 / sigma ** 2) for x in range(k)], dtype=torch.float32)
|
702 |
+
kernel /= kernel.sum() # Normalize the kernel
|
703 |
+
return kernel
|
704 |
+
|
705 |
+
|
706 |
+
def blur_dim(t, window=5, std_dev=1, dim=-1):
|
707 |
+
shape = t.shape
|
708 |
+
n_dims = len(shape)
|
709 |
+
|
710 |
+
# Create the Gaussian kernel
|
711 |
+
with torch.no_grad():
|
712 |
+
x = torch.linspace(-2, 2, window, device=t.device, dtype=torch.float32)
|
713 |
+
k = torch.exp(-x ** 2 / (2 * std_dev ** 2))
|
714 |
+
k = k / k.sum()
|
715 |
+
k = k.view(1, 1, window).detach()
|
716 |
+
|
717 |
+
# Calculate padding
|
718 |
+
pad = window // 2
|
719 |
+
|
720 |
+
# Move the target dimension to the end
|
721 |
+
permute_order = list(range(n_dims))
|
722 |
+
permute_order.append(permute_order.pop(dim))
|
723 |
+
t_permuted = t.permute(permute_order)
|
724 |
+
|
725 |
+
# Flatten all dimensions except the last one
|
726 |
+
new_shape = (-1, t_permuted.size(-1))
|
727 |
+
t_flattened = t_permuted.reshape(new_shape)
|
728 |
+
|
729 |
+
# Pad the tensor
|
730 |
+
t_padded = F.pad(t_flattened.unsqueeze(1), (pad, pad), mode="replicate")
|
731 |
+
|
732 |
+
# Apply convolution
|
733 |
+
blurred = F.conv1d(t_padded, k)
|
734 |
+
|
735 |
+
# Reshape back to original
|
736 |
+
blurred = blurred.squeeze(1).reshape(*t_permuted.shape)
|
737 |
+
blurred = blurred.permute([permute_order.index(i) for i in range(n_dims)])
|
738 |
+
|
739 |
+
return blurred
|
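A quick check of `blur_dim`: smoothing a similarity volume along its time axis preserves the shape while low-passing that dimension. The tensor layout is only an example of how DenseAV-style `(B, heads, H, W, T)` scores could be smoothed.

```python
import torch

sims = torch.randn(2, 4, 16, 16, 50)                    # e.g. (B, heads, H, W, T)
smoothed = blur_dim(sims, window=5, std_dev=1, dim=-1)  # Gaussian blur over time
print(smoothed.shape)                                   # torch.Size([2, 4, 16, 16, 50])
```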
DenseAV/denseav/train.py
ADDED
@@ -0,0 +1,1222 @@
1 |
+
import os
|
2 |
+
from collections import deque
|
3 |
+
from itertools import combinations
|
4 |
+
from os.path import join
|
5 |
+
|
6 |
+
import hydra
|
7 |
+
import numpy as np
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
import torch
|
10 |
+
import torch.distributed as dist
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from omegaconf import DictConfig, OmegaConf
|
13 |
+
from peft import get_peft_model, LoraConfig
|
14 |
+
from pytorch_lightning import Trainer
|
15 |
+
from pytorch_lightning import seed_everything
|
16 |
+
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
|
17 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
18 |
+
from pytorch_lightning.utilities import grad_norm
|
19 |
+
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, SequentialLR, LambdaLR
|
20 |
+
from torchmetrics.functional.classification import binary_average_precision
|
21 |
+
|
22 |
+
from huggingface_hub import PyTorchModelHubMixin
|
23 |
+
|
24 |
+
from denseav.aggregators import get_aggregator
|
25 |
+
from denseav.aligners import get_aligner, ProgressiveGrowing
|
26 |
+
from denseav.constants import *
|
27 |
+
from denseav.data.AVDatasets import AVDataModule
|
28 |
+
from denseav.shared import flatten_preds, GatherLayer, \
|
29 |
+
get_image_featurizer, get_audio_featurizer, RollingAvg, create_model_from_cfg
|
30 |
+
|
31 |
+
torch.multiprocessing.set_sharing_strategy('file_system')
|
32 |
+
|
33 |
+
|
34 |
+
def _imposter_indices_helper(true_indices: torch.Tensor, samples: torch.Tensor):
|
35 |
+
mask = (true_indices == samples).to(torch.int64)
|
36 |
+
n = mask.shape[0]
|
37 |
+
|
38 |
+
if not mask.any():
|
39 |
+
return samples
|
40 |
+
else:
|
41 |
+
new_samples = torch.randint(0, n, size=(n,), device=true_indices.device)
|
42 |
+
comb_samples = mask * new_samples + (1 - mask) * samples
|
43 |
+
return _imposter_indices_helper(true_indices, comb_samples)
|
44 |
+
|
45 |
+
|
46 |
+
def imposter_indices(n, device):
|
47 |
+
return _imposter_indices_helper(
|
48 |
+
torch.arange(0, n, device=device),
|
49 |
+
torch.randint(0, n, size=(n,), device=device))
|
50 |
+
|
51 |
+
|
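`imposter_indices` draws, for each row of a batch, the index of a different row to act as a negative; it resamples until no index equals its own position (values may repeat). A quick check:

```python
import torch

idx = imposter_indices(8, device="cpu")
print(idx)                                  # random indices with idx[i] != i for every i
assert (idx != torch.arange(8)).all()
```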
52 |
+
def get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type):
|
53 |
+
max_t = audio_outputs.shape[-1]
|
54 |
+
oh = F.one_hot(n_frames - 1, num_classes=max_t)
|
55 |
+
audio_mask = 1 - torch.cumsum(oh, dim=1)
|
56 |
+
audio_mask = F.pad(audio_mask, [1, 0], value=1)[:, :max_t].to(audio_outputs.dtype)
|
57 |
+
|
58 |
+
full_sim = torch.einsum("bct,bchw->bthw", audio_outputs, image_outputs)
|
59 |
+
expanded_am = audio_mask.unsqueeze(-1).unsqueeze(-1)
|
60 |
+
|
61 |
+
if sim_type.endswith("mi"):
|
62 |
+
offset = 10 * (full_sim.max() - full_sim.min())
|
63 |
+
full_sim = (full_sim - ((1 - expanded_am) * offset)).max(1, keepdim=True).values
|
64 |
+
|
65 |
+
if sim_type.startswith("mi"):
|
66 |
+
full_sim = full_sim.max(-1, keepdim=True).values.max(-2, keepdim=True).values
|
67 |
+
|
68 |
+
if sim_type.endswith("sa"):
|
69 |
+
full_sim = (full_sim * (expanded_am / expanded_am.sum(1, keepdim=True).clamp_min(1))).sum(1, keepdim=True)
|
70 |
+
|
71 |
+
return full_sim.mean(dim=[1, 2, 3])
|
72 |
+
|
73 |
+
|
74 |
+
def sampled_margin_rank_loss(image_outputs, audio_outputs, n_frames, sim_type, margin=1.):
|
75 |
+
"""
|
76 |
+
Computes the triplet margin ranking loss for each anchor image/caption pair
|
77 |
+
The impostor image/caption is randomly sampled from the minibatch
|
78 |
+
"""
|
79 |
+
assert (image_outputs.dim() == 4)
|
80 |
+
assert (audio_outputs.dim() == 3)
|
81 |
+
n = image_outputs.size(0)
|
82 |
+
imp_ind_i = imposter_indices(n, image_outputs.device)
|
83 |
+
imp_ind_a = imposter_indices(n, image_outputs.device)
|
84 |
+
true_sim = get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type)
|
85 |
+
imp_sim_i = get_sim_per_row(image_outputs[imp_ind_i], audio_outputs, n_frames, sim_type)
|
86 |
+
imp_sim_a = get_sim_per_row(image_outputs, audio_outputs[imp_ind_a], n_frames[imp_ind_a], sim_type)
|
87 |
+
a2i_loss = (margin + imp_sim_i - true_sim).clamp_min(0)
|
88 |
+
i2a_loss = (margin + imp_sim_a - true_sim).clamp_min(0)
|
89 |
+
return (a2i_loss + i2a_loss).mean() / 2
|
90 |
+
|
91 |
+
|
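A toy invocation of `sampled_margin_rank_loss` on random aligned features: image features are `(B, C, H, W)`, audio features `(B, C, T)`, and `n_frames` gives the valid audio length per item. The shapes and the "misa" similarity type are illustrative choices, not DenseAV's training configuration.

```python
import torch

image_outputs = torch.randn(4, 32, 7, 7)
audio_outputs = torch.randn(4, 32, 20)
n_frames = torch.tensor([20, 15, 10, 20])   # valid audio frames per clip

loss = sampled_margin_rank_loss(image_outputs, audio_outputs, n_frames, sim_type="misa")
print(loss)                                 # scalar triplet-margin loss
```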
92 |
+
class SimilarityCalibrator(torch.nn.Module):
|
93 |
+
|
94 |
+
def __init__(self, cal_init, max_w=100, min_w=.01, subtract_mean=True, use_bias=False):
|
95 |
+
super().__init__()
|
96 |
+
self.max_w = max_w
|
97 |
+
self.min_w = min_w
|
98 |
+
self.w = torch.nn.Parameter(torch.tensor([cal_init]).log())
|
99 |
+
|
100 |
+
self.use_bias = use_bias
|
101 |
+
if self.use_bias:
|
102 |
+
self.b = torch.nn.Parameter(torch.tensor([0.0]))
|
103 |
+
|
104 |
+
self.subtract_mean = subtract_mean
|
105 |
+
|
106 |
+
def get_w(self):
|
107 |
+
return torch.exp(self.w).clamp_max(self.max_w).clamp_min(self.min_w)
|
108 |
+
|
109 |
+
def forward(self, x):
|
110 |
+
sims = self.get_w() * x
|
111 |
+
|
112 |
+
if self.use_bias:
|
113 |
+
sims = sims + self.b
|
114 |
+
|
115 |
+
if self.subtract_mean:
|
116 |
+
return sims - sims.mean()
|
117 |
+
else:
|
118 |
+
return sims
|
119 |
+
|
120 |
+
|
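`SimilarityCalibrator` learns a positive temperature `w` (and an optional bias) that scales raw similarity logits before the contrastive loss; a standalone check on a random similarity matrix.

```python
import torch

cal = SimilarityCalibrator(cal_init=10.0, subtract_mean=True, use_bias=False)
raw_sims = torch.randn(4, 4)        # batch-vs-batch similarity logits

calibrated = cal(raw_sims)
print(cal.get_w())                  # starts near 10.0, clamped to [0.01, 100]
print(calibrated.mean())            # ~0 because subtract_mean=True
```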
121 |
+
class SpatialDropout(torch.nn.Module):
|
122 |
+
|
123 |
+
def __init__(self, p, *args, **kwargs):
|
124 |
+
super().__init__(*args, **kwargs)
|
125 |
+
self.p = p
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
b, c, h, w = x.shape
|
129 |
+
dropout = torch.rand((b, 1, h, w), dtype=x.dtype, device=x.device) > self.p
|
130 |
+
|
131 |
+
if self.training:
|
132 |
+
return x * dropout
|
133 |
+
else:
|
134 |
+
return x
|
135 |
+
|
136 |
+
|
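`SpatialDropout` zeroes whole spatial positions (all channels at once) during training and is the identity in eval mode; note it does not rescale the surviving activations. A quick behavioral check:

```python
import torch

drop = SpatialDropout(p=0.5)
feats = torch.randn(2, 16, 7, 7)

drop.train()
out_train = drop(feats)              # roughly half of the 7x7 positions zeroed across all channels

drop.eval()
out_eval = drop(feats)
assert torch.equal(out_eval, feats)  # no-op at eval time
```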
137 |
+
class LitAVAligner(pl.LightningModule, PyTorchModelHubMixin, repo_url="https://github.com/mhamilton723/DenseAV", license="mit", tags=["denseav"]):
|
138 |
+
def __init__(self,
|
139 |
+
code_dim,
|
140 |
+
image_model_type,
|
141 |
+
image_model_token_type,
|
142 |
+
image_aligner_type,
|
143 |
+
image_pool_width,
|
144 |
+
audio_model_type,
|
145 |
+
audio_aligner_type,
|
146 |
+
audio_pool_width,
|
147 |
+
audio_lora,
|
148 |
+
audio_lora_rank,
|
149 |
+
image_lora,
|
150 |
+
image_lora_rank,
|
151 |
+
gradient_clipping,
|
152 |
+
learn_audio_cls,
|
153 |
+
silence_l1,
|
154 |
+
silence_l2,
|
155 |
+
tv_weight,
|
156 |
+
nonneg_sim,
|
157 |
+
nonneg_pressure,
|
158 |
+
pretrain_lr,
|
159 |
+
lr,
|
160 |
+
lr_warmup,
|
161 |
+
lr_schedule,
|
162 |
+
lr_cycle_length,
|
163 |
+
optimizer,
|
164 |
+
gather_tensors,
|
165 |
+
sim_agg_type,
|
166 |
+
sim_agg_heads,
|
167 |
+
sim_use_cls,
|
168 |
+
disentangle_weight,
|
169 |
+
norm_vectors,
|
170 |
+
cal_init,
|
171 |
+
cal_balance_weight,
|
172 |
+
loss_type,
|
173 |
+
loss_margin,
|
174 |
+
mask_silence,
|
175 |
+
finetune_image_model,
|
176 |
+
finetune_audio_model,
|
177 |
+
use_cached_embs,
|
178 |
+
output_root,
|
179 |
+
neg_audio,
|
180 |
+
neg_audio_weight,
|
181 |
+
head_agg,
|
182 |
+
adaptive_clipping,
|
183 |
+
specialization_weight,
|
184 |
+
spatial_dropout,
|
185 |
+
channel_dropout,
|
186 |
+
mixup_weight,
|
187 |
+
memory_buffer_size,
|
188 |
+
loss_leak,
|
189 |
+
):
|
190 |
+
super().__init__()
|
191 |
+
|
192 |
+
self.code_dim = code_dim
|
193 |
+
self.image_model_type = image_model_type
|
194 |
+
self.image_model_token_type = image_model_token_type
|
195 |
+
self.image_aligner_type = image_aligner_type
|
196 |
+
self.image_pool_width = image_pool_width
|
197 |
+
self.audio_model_type = audio_model_type
|
198 |
+
self.audio_aligner_type = audio_aligner_type
|
199 |
+
self.audio_pool_width = audio_pool_width
|
200 |
+
|
201 |
+
self.gradient_clipping = gradient_clipping
|
202 |
+
self.learn_audio_cls = learn_audio_cls
|
203 |
+
self.silence_l1 = silence_l1
|
204 |
+
self.silence_l2 = silence_l2
|
205 |
+
|
206 |
+
self.tv_weight = tv_weight
|
207 |
+
self.nonneg_sim = nonneg_sim
|
208 |
+
self.nonneg_pressure = nonneg_pressure
|
209 |
+
self.pretrain_lr = pretrain_lr
|
210 |
+
self.lr = lr
|
211 |
+
self.lr_warmup = lr_warmup
|
212 |
+
self.lr_schedule = lr_schedule
|
213 |
+
self.lr_cycle_length = lr_cycle_length
|
214 |
+
self.optimizer = optimizer
|
215 |
+
self.gather_tensors = gather_tensors
|
216 |
+
self.sim_agg_type = sim_agg_type
|
217 |
+
self.sim_agg_heads = sim_agg_heads
|
218 |
+
self.sim_use_cls = sim_use_cls
|
219 |
+
self.disentangle_weight = disentangle_weight
|
220 |
+
|
221 |
+
self.norm_vectors = norm_vectors
|
222 |
+
self.cal_init = cal_init
|
223 |
+
self.cal_balance_weight = cal_balance_weight
|
224 |
+
self.loss_type = loss_type
|
225 |
+
self.loss_margin = loss_margin
|
226 |
+
self.mask_silence = mask_silence
|
227 |
+
self.finetune_image_model = finetune_image_model
|
228 |
+
self.finetune_audio_model = finetune_audio_model
|
229 |
+
self.use_cached_embs = use_cached_embs
|
230 |
+
self.output_root = output_root
|
231 |
+
self.audio_lora = audio_lora
|
232 |
+
self.audio_lora_rank = audio_lora_rank
|
233 |
+
self.image_lora = image_lora
|
234 |
+
self.image_lora_rank = image_lora_rank
|
235 |
+
self.neg_audio = neg_audio
|
236 |
+
self.neg_audio_weight = neg_audio_weight
|
237 |
+
self.head_agg = head_agg
|
238 |
+
|
239 |
+
self.adaptive_clipping = adaptive_clipping
|
240 |
+
self.specialization_weight = specialization_weight
|
241 |
+
self.spatial_dropout = spatial_dropout
|
242 |
+
self.channel_dropout = channel_dropout
|
243 |
+
self.mixup_weight = mixup_weight
|
244 |
+
|
245 |
+
self.memory_buffer_size = memory_buffer_size
|
246 |
+
self.memory_buffer = deque(maxlen=self.memory_buffer_size)
|
247 |
+
self.loss_leak = loss_leak
|
248 |
+
|
249 |
+
self.full_train = False # Added by me
|
250 |
+
|
251 |
+
if self.audio_model_type in {"audiomae", "audiomae-finetuned", "cavmae", "cavmae-mixed", "imagebind"}:
|
252 |
+
self.audio_input = "spec"
|
253 |
+
elif self.audio_model_type == "davenet":
|
254 |
+
self.audio_input = "davenet_spec"
|
255 |
+
elif self.audio_model_type == "fnac":
|
256 |
+
self.audio_input = "fnac_spec"
|
257 |
+
else:
|
258 |
+
self.audio_input = "audio"
|
259 |
+
|
260 |
+
extra_model_args = dict(output_root=output_root)
|
261 |
+
|
262 |
+
self.image_model, _, self.image_feat_dim = get_image_featurizer(
|
263 |
+
image_model_type, token_type=self.image_model_token_type, **extra_model_args)
|
264 |
+
|
265 |
+
self.image_model.eval()
|
266 |
+
if not self.finetune_image_model:
|
267 |
+
for param in self.image_model.parameters():
|
268 |
+
param.requires_grad = False
|
269 |
+
|
270 |
+
if image_model_type in {"cavmae", "cavmae-mixed", "imagebind", "fnac"}:
|
271 |
+
extra_model_args["model"] = self.image_model.model
|
272 |
+
|
273 |
+
if use_cached_embs:
|
274 |
+
_, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
|
275 |
+
else:
|
276 |
+
self.audio_model, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
|
277 |
+
|
278 |
+
self.audio_model.eval()
|
279 |
+
if not self.finetune_audio_model:
|
280 |
+
for param in self.audio_model.parameters():
|
281 |
+
param.requires_grad = False
|
282 |
+
|
283 |
+
if self.image_lora:
|
284 |
+
if self.image_model_type in {"sam", "dino8", "dinov2", "cavmae", "cavmae-mixed"}:
|
285 |
+
target_modules = ["qkv"]
|
286 |
+
elif self.image_model_type == "clip":
|
287 |
+
target_modules = ["out_proj"]
|
288 |
+
elif self.image_model_type == "imagebind":
|
289 |
+
target_modules = ["out_proj", "fc1", "fc2"]
|
290 |
+
else:
|
291 |
+
target_modules = ["q", "k", "v"]
|
292 |
+
|
293 |
+
peft_config = LoraConfig(
|
294 |
+
target_modules=target_modules,
|
295 |
+
inference_mode=False,
|
296 |
+
r=image_lora_rank,
|
297 |
+
lora_alpha=32,
|
298 |
+
lora_dropout=0.1
|
299 |
+
)
|
300 |
+
self.image_model = get_peft_model(self.image_model, peft_config)
|
301 |
+
self.image_model.print_trainable_parameters()
|
302 |
+
|
303 |
+
if self.audio_lora:
|
304 |
+
if self.audio_model_type == "hubert":
|
305 |
+
target_modules = ["q_proj", "k_proj", "v_proj"]
|
306 |
+
else:
|
307 |
+
target_modules = ["q", "k", "v"]
|
308 |
+
|
309 |
+
peft_config = LoraConfig(
|
310 |
+
inference_mode=False,
|
311 |
+
target_modules=target_modules,
|
312 |
+
r=audio_lora_rank,
|
313 |
+
lora_alpha=32,
|
314 |
+
lora_dropout=0.1
|
315 |
+
)
|
316 |
+
self.audio_model = get_peft_model(self.audio_model, peft_config)
|
317 |
+
self.audio_model.print_trainable_parameters()
|
318 |
+
|
319 |
+
shared_aligner_args = dict(out_dim=self.code_dim)
|
320 |
+
|
321 |
+
self.audio_aligner = get_aligner(
|
322 |
+
self.audio_aligner_type, self.audio_feat_dim, **shared_aligner_args)
|
323 |
+
self.image_aligner = get_aligner(
|
324 |
+
self.image_aligner_type, self.image_feat_dim, **shared_aligner_args)
|
325 |
+
|
326 |
+
if self.loss_type == "nce":
|
327 |
+
self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=True, use_bias=False)
|
328 |
+
else:
|
329 |
+
self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=False, use_bias=True)
|
330 |
+
|
331 |
+
if self.learn_audio_cls:
|
332 |
+
self.audio_cls = torch.nn.Parameter(torch.randn(self.audio_feat_dim))
|
333 |
+
|
334 |
+
if self.spatial_dropout > 0.0:
|
335 |
+
self.spatial_dropout_layer = SpatialDropout(self.spatial_dropout)
|
336 |
+
|
337 |
+
if self.channel_dropout > 0.0:
|
338 |
+
self.channel_dropout_layer = torch.nn.Dropout2d(self.channel_dropout)
|
339 |
+
|
340 |
+
self.sim_agg = get_aggregator(
|
341 |
+
self.sim_agg_type,
|
342 |
+
self.nonneg_sim,
|
343 |
+
self.mask_silence,
|
344 |
+
self.sim_agg_heads,
|
345 |
+
self.head_agg,
|
346 |
+
self.sim_use_cls,
|
347 |
+
dim=self.image_feat_dim
|
348 |
+
)
|
349 |
+
|
350 |
+
self.hparams_logged = False
|
351 |
+
self.rolling_avg = RollingAvg(50)
|
352 |
+
self.grad_avg = RollingAvg(50, nonzero=True)
|
353 |
+
|
354 |
+
self.save_hyperparameters()
|
355 |
+
|
356 |
+
def set_full_train(self, full_train):
|
357 |
+
self.full_train = full_train
|
358 |
+
|
359 |
+
def prep_feats(self, feats, is_audio):
|
360 |
+
|
361 |
+
if not is_audio and self.training and self.image_pool_width > 1:
|
362 |
+
feats = torch.nn.AvgPool2d(self.image_pool_width)(feats)
|
363 |
+
|
364 |
+
if is_audio and self.training and self.audio_pool_width > 1:
|
365 |
+
feats = torch.nn.AvgPool2d((1, self.audio_pool_width))(feats)
|
366 |
+
|
367 |
+
if self.norm_vectors:
|
368 |
+
feats = F.normalize(feats, dim=1)
|
369 |
+
|
370 |
+
return feats
|
371 |
+
|
372 |
+
def on_before_optimizer_step(self, optimizer, optimizer_idx):
|
373 |
+
norms = grad_norm(self, norm_type=2)
|
374 |
+
avg_grads = self.grad_avg.get_all()
|
375 |
+
params = {
|
376 |
+
f"grad_2.0_norm/{name}": p
|
377 |
+
for name, p in self.named_parameters()
|
378 |
+
if p.grad is not None
|
379 |
+
}
|
380 |
+
|
381 |
+
if self.adaptive_clipping:
|
382 |
+
for k in norms.keys():
|
383 |
+
if k in params:
|
384 |
+
avg_grad = max(avg_grads.get(k, norms[k]), 1e-5)
|
385 |
+
if self.global_step > 10 and norms[k] > avg_grad * 5:
|
386 |
+
print(f"Bad grad for {k}: {norms[k]} scaling to {avg_grad * 5}")
|
387 |
+
torch.nn.utils.clip_grad_norm_(params[k], avg_grad * 5)
|
388 |
+
norms[k] = avg_grad * 5
|
389 |
+
|
390 |
+
if norms[k] > self.gradient_clipping:
|
391 |
+
# print(f"Bad grad for {k}: {norms[k]} scaling to {self.gradient_clipping}")
|
392 |
+
torch.nn.utils.clip_grad_norm_(params[k], self.gradient_clipping)
|
393 |
+
|
394 |
+
# self.grad_avg.add_all(norms)
|
395 |
+
# self.log_dict(norms)
|
396 |
+
|
397 |
+
def interpolate_mask(self, mask, target_length, discrete):
|
398 |
+
b, t = mask.shape
|
399 |
+
|
400 |
+
mask = F.interpolate(mask.reshape(b, 1, 1, t), (1, target_length), mode="bilinear") \
|
401 |
+
.reshape(b, target_length)
|
402 |
+
|
403 |
+
if discrete:
|
404 |
+
mask = mask > 0.01
|
405 |
+
sums = mask.sum(1)
|
406 |
+
all_zeros = torch.where(sums == 0)[0]
|
407 |
+
if len(all_zeros) > 0:
|
408 |
+
print("Fixing a bad mask")
|
409 |
+
for entry in all_zeros:
|
410 |
+
mask[entry, torch.randint(0, target_length - 1, size=())] = True
|
411 |
+
else:
|
412 |
+
return mask
|
413 |
+
return mask
|
414 |
+
|
415 |
+
def forward_audio(self, batch):
|
416 |
+
if self.use_cached_embs:
|
417 |
+
audio_feats = batch["audio_emb"]
|
418 |
+
if "audio_cls" in batch:
|
419 |
+
audio_cls = batch["audio_cls"]
|
420 |
+
else:
|
421 |
+
audio_cls = None
|
422 |
+
else:
|
423 |
+
audio = batch[self.audio_input]
|
424 |
+
|
425 |
+
if self.full_train:
|
426 |
+
audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
|
427 |
+
else:
|
428 |
+
with torch.no_grad():
|
429 |
+
audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
|
430 |
+
|
431 |
+
mask = batch[AUDIO_MASK] if AUDIO_MASK in batch else torch.ones_like(audio)
|
432 |
+
pos_mask = batch[AUDIO_POS_MASK] if AUDIO_POS_MASK in batch else torch.ones_like(audio)
|
433 |
+
|
434 |
+
if self.learn_audio_cls:
|
435 |
+
assert audio_cls is None
|
436 |
+
audio_cls = torch.broadcast_to(self.audio_cls.unsqueeze(0), (audio_feats.shape[0], audio_feats.shape[1]))
|
437 |
+
|
438 |
+
aligned_audio_feats, aligned_audio_cls = self.audio_aligner(audio_feats, audio_cls)
|
439 |
+
|
440 |
+
if self.channel_dropout > 0.0:
|
441 |
+
aligned_audio_feats = self.channel_dropout_layer(aligned_audio_feats)
|
442 |
+
|
443 |
+
aligned_audio_feats = self.prep_feats(aligned_audio_feats, is_audio=True)
|
444 |
+
audio_mask = self.interpolate_mask(mask, aligned_audio_feats.shape[-1], True)
|
445 |
+
audio_pos_mask = self.interpolate_mask(pos_mask, aligned_audio_feats.shape[-1], False)
|
446 |
+
|
447 |
+
ret = {
|
448 |
+
AUDIO_MASK: audio_mask,
|
449 |
+
AUDIO_POS_MASK: audio_pos_mask,
|
450 |
+
AUDIO_FEATS: aligned_audio_feats,
|
451 |
+
}
|
452 |
+
|
453 |
+
if aligned_audio_cls is not None:
|
454 |
+
ret[AUDIO_CLS] = aligned_audio_cls
|
455 |
+
|
456 |
+
return ret
|
457 |
+
|
458 |
+
# @autocast(device_type="cuda", enabled=False)
|
459 |
+
def forward_image(self, batch, max_batch_size=None):
|
460 |
+
|
461 |
+
with torch.no_grad():
|
462 |
+
image = batch[IMAGE_INPUT]
|
463 |
+
b, nf, c, h, w = image.shape
|
464 |
+
image = image.reshape(b * nf, c, h, w)
|
465 |
+
|
466 |
+
if max_batch_size is None:
|
467 |
+
max_batch_size = image.shape[0]
|
468 |
+
|
469 |
+
chunks = [image[i:i + max_batch_size] for i in range(0, image.shape[0], max_batch_size)]
|
470 |
+
|
471 |
+
all_image_feats = []
|
472 |
+
all_image_cls = []
|
473 |
+
|
474 |
+
for chunk in chunks:
|
475 |
+
if self.full_train:
|
476 |
+
image_feats, image_cls = self.image_model(chunk, include_cls=True)
|
477 |
+
else:
|
478 |
+
with torch.no_grad():
|
479 |
+
image_feats, image_cls = self.image_model(chunk, include_cls=True)
|
480 |
+
|
481 |
+
aligned_image_feats, aligned_image_cls = self.image_aligner(image_feats, image_cls)
|
482 |
+
|
483 |
+
all_image_feats.append(aligned_image_feats)
|
484 |
+
all_image_cls.append(aligned_image_cls)
|
485 |
+
|
486 |
+
# Stitch the chunks back together
|
487 |
+
aligned_image_feats = torch.cat(all_image_feats, dim=0)
|
488 |
+
aligned_image_cls = torch.cat(all_image_cls, dim=0)
|
489 |
+
|
490 |
+
if self.channel_dropout > 0.0:
|
491 |
+
aligned_image_feats = self.channel_dropout_layer(aligned_image_feats)
|
492 |
+
|
493 |
+
if self.spatial_dropout > 0.0:
|
494 |
+
aligned_image_feats = self.spatial_dropout_layer(aligned_image_feats)
|
495 |
+
|
496 |
+
aligned_image_feats = self.prep_feats(aligned_image_feats, is_audio=False)
|
497 |
+
ret = {IMAGE_FEATS: aligned_image_feats}
|
498 |
+
|
499 |
+
if IMAGE_MASK in batch:
|
500 |
+
with torch.no_grad():
|
501 |
+
mask = batch[IMAGE_MASK]
|
502 |
+
mask = mask.reshape(b * nf, 1, h, w)
|
503 |
+
b, c, h, w = aligned_image_feats.shape
|
504 |
+
mask = F.adaptive_avg_pool2d(mask.to(aligned_image_feats), output_size=(h, w))
|
505 |
+
ret[IMAGE_MASK] = mask
|
506 |
+
|
507 |
+
if aligned_image_cls is not None:
|
508 |
+
ret[IMAGE_CLS] = aligned_image_cls
|
509 |
+
|
510 |
+
return ret
|
511 |
+
|
512 |
+
def forward(self, batch):
|
513 |
+
audio_feat_dict = self.forward_audio(batch)
|
514 |
+
image_feat_dict = self.forward_image(batch)
|
515 |
+
return {**image_feat_dict, **audio_feat_dict}
|
516 |
+
|
517 |
+
def contrast_loss(self, sims):
|
518 |
+
b = sims.shape[0]
|
519 |
+
sims = sims - torch.eye(b, b, device=sims.device) * self.loss_margin
|
520 |
+
sims_1 = sims
|
521 |
+
sims_2 = sims.permute(1, 0)
|
522 |
+
|
523 |
+
if self.loss_leak > 0.0:
|
524 |
+
id = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
|
525 |
+
label_mask = id * (1 - self.loss_leak)
|
526 |
+
label_mask += (1 - id) * self.loss_leak / (sims_1.shape[0] - 1)
|
527 |
+
label_mask /= label_mask.sum(dim=1, keepdim=True)
|
528 |
+
else:
|
529 |
+
label_mask = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
|
530 |
+
|
531 |
+
labels = torch.arange(0, sims.shape[0], device=sims.device)
|
532 |
+
self.rolling_avg.add(f"acc/1", (sims.argmax(dim=1) == labels).to(sims).mean())
|
533 |
+
self.rolling_avg.add(f"acc/2", (sims.argmax(dim=0) == labels).to(sims).mean())
|
534 |
+
|
535 |
+
if self.loss_type == "margin":
|
536 |
+
margin_loss_tensor = (sims - torch.diag(sims)).clamp_min(0)
|
537 |
+
margin_loss = margin_loss_tensor.mean()
|
538 |
+
self.rolling_avg.add(f"loss/frac_nonzero", (margin_loss_tensor > 0).to(sims).mean())
|
539 |
+
self.rolling_avg.add(f"loss/margin", margin_loss)
|
540 |
+
return margin_loss
|
541 |
+
elif self.loss_type == "ce":
|
542 |
+
ce_loss = 1 / 2 * F.cross_entropy(sims_1, labels) + \
|
543 |
+
1 / 2 * F.cross_entropy(sims_2, labels)
|
544 |
+
self.rolling_avg.add(f"loss/ce", ce_loss)
|
545 |
+
return ce_loss
|
546 |
+
elif self.loss_type == "bce":
|
547 |
+
bce_loss = F.binary_cross_entropy_with_logits(sims_1.flatten(), label_mask.flatten())
|
548 |
+
self.rolling_avg.add(f"loss/bce", bce_loss)
|
549 |
+
return bce_loss
|
550 |
+
elif self.loss_type == "nce":
|
551 |
+
nce_loss = 1 / 2 * (-F.log_softmax(sims_1, dim=-1) * label_mask).sum(1).mean() + \
|
552 |
+
1 / 2 * (-F.log_softmax(sims_2, dim=-1) * label_mask).sum(1).mean()
|
553 |
+
self.rolling_avg.add(f"loss/nce", nce_loss)
|
554 |
+
return nce_loss
|
555 |
+
else:
|
556 |
+
raise ValueError(f"Unknown loss type {self.loss_type}")
|
557 |
+
|
558 |
+
def loss(self, preds):
|
559 |
+
image_feats = preds[IMAGE_FEATS]
|
560 |
+
audio_feats = preds[AUDIO_FEATS]
|
561 |
+
audio_mask = preds[AUDIO_MASK]
|
562 |
+
image_mask = preds[IMAGE_MASK]
|
563 |
+
audio_pos_mask = preds[AUDIO_POS_MASK]
|
564 |
+
if DATA_SOURCE in preds:
|
565 |
+
source = preds[DATA_SOURCE].to(torch.int64)
|
566 |
+
else:
|
567 |
+
source = None
|
568 |
+
|
569 |
+
uncal_sims = self.sim_agg(preds, agg_heads=True)
|
570 |
+
sims = self.sim_cal(uncal_sims)
|
571 |
+
|
572 |
+
_mask = 1 - torch.eye(sims.shape[0], device=sims.device)
|
573 |
+
self.log(f"sim/pos", torch.diag(sims).mean())
|
574 |
+
self.log(f"sim/neg", (sims * _mask).sum() / (_mask.sum()))
|
575 |
+
self.log(f"sim/uncal_pos", torch.diag(uncal_sims).mean())
|
576 |
+
self.log(f"sim/uncal_neg", (uncal_sims * _mask).sum() / (_mask.sum()))
|
577 |
+
|
578 |
+
b, c, h, w = image_feats.shape
|
579 |
+
b, c, f, t = audio_feats.shape
|
580 |
+
n_samples = 250
|
581 |
+
|
582 |
+
nh = self.sim_agg_heads
|
583 |
+
image_feats_by_head = image_feats.reshape(b, self.sim_agg_heads, c // nh, h, w)
|
584 |
+
audio_feats_by_head = audio_feats.reshape(b, self.sim_agg_heads, c // nh, f, t)
|
585 |
+
|
586 |
+
def maybe_clamp(t):
|
587 |
+
return t.clamp_min(0) if self.nonneg_sim else t
|
588 |
+
|
589 |
+
paired_sim_raw = self.sim_agg.get_pairwise_sims(preds, raw=True, agg_sim=False, agg_heads=False)
|
590 |
+
paired_sim = maybe_clamp(paired_sim_raw)
|
591 |
+
|
592 |
+
loss = 0.0
|
593 |
+
|
594 |
+
if self.nonneg_pressure:
|
595 |
+
afb, afk, afc, aff, aft = audio_feats_by_head.shape
|
596 |
+
ifb, ifk, ifc, ifh, ifw = image_feats_by_head.shape
|
597 |
+
assert (afb == ifb)
|
598 |
+
|
599 |
+
device = audio_feats_by_head.device
|
600 |
+
random_b = torch.randint(0, afb, size=(n_samples,), device=device)
|
601 |
+
random_t = torch.randint(0, aft, size=(n_samples,), device=device)
|
602 |
+
random_f = torch.randint(0, aff, size=(n_samples,), device=device)
|
603 |
+
random_h = torch.randint(0, ifh, size=(n_samples,), device=device)
|
604 |
+
random_w = torch.randint(0, ifw, size=(n_samples,), device=device)
|
605 |
+
|
606 |
+
random_audio_feats = audio_feats_by_head[random_b, :, :, random_f, random_t]
|
607 |
+
random_image_feats = image_feats_by_head[random_b, :, :, random_h, random_w]
|
608 |
+
random_sim_raw = torch.einsum("bkc,dkc->bdk", random_audio_feats, random_image_feats)
|
609 |
+
|
610 |
+
nonneg_loss = random_sim_raw.clamp_max(0).square().mean()
|
611 |
+
self.rolling_avg.add(f"loss/nonneg", nonneg_loss)
|
612 |
+
loss += nonneg_loss * self.nonneg_pressure
|
613 |
+
|
614 |
+
if self.silence_l1 > 0 or self.silence_l2 > 0:
|
615 |
+
masked_b, masked_t = torch.where(~audio_mask)
|
616 |
+
if len(masked_b) > n_samples:
|
617 |
+
subset = torch.randperm(len(masked_b))[:n_samples]
|
618 |
+
masked_b = masked_b[subset]
|
619 |
+
masked_t = masked_t[subset]
|
620 |
+
|
621 |
+
if len(masked_b) == n_samples:
|
622 |
+
silent_audio_feats = audio_feats_by_head[masked_b, :, :, :, masked_t].mean(-1) # d k c
|
623 |
+
silence_tensor = maybe_clamp(
|
624 |
+
torch.einsum("bkchw,dkc->bkdhw", image_feats_by_head, silent_audio_feats))
|
625 |
+
|
626 |
+
silence_l1_loss = silence_tensor.abs().mean()
|
627 |
+
self.rolling_avg.add(f"loss/silence_l1", silence_l1_loss)
|
628 |
+
loss += silence_l1_loss * self.silence_l1
|
629 |
+
|
630 |
+
silence_l2_loss = silence_tensor.square().mean()
|
631 |
+
self.rolling_avg.add(f"loss/silence_l2", silence_l2_loss)
|
632 |
+
loss += silence_l2_loss * self.silence_l2
|
633 |
+
else:
|
634 |
+
pass
|
635 |
+
|
636 |
+
if self.neg_audio_weight > 0 and self.neg_audio:
|
637 |
+
b, t = audio_pos_mask.shape
|
638 |
+
negative_weight = ((1 - audio_pos_mask) * audio_mask.to(sims)).reshape(b, 1, 1, 1, 1, t)
|
639 |
+
negative_weight = torch.broadcast_to(negative_weight, paired_sim.shape)
|
640 |
+
if negative_weight.sum() > 0:
|
641 |
+
neg_audio_loss = (paired_sim.square() * negative_weight).sum() \
|
642 |
+
/ negative_weight.sum().clamp_min(0.1)
|
643 |
+
self.rolling_avg.add(f"loss/neg_audio", neg_audio_loss)
|
644 |
+
self.rolling_avg.add(f"loss/neg_weight_avg", negative_weight.mean())
|
645 |
+
loss += neg_audio_loss * self.neg_audio_weight
|
646 |
+
else:
|
647 |
+
print("WARNING: No negative samples found in batch")
|
648 |
+
|
649 |
+
if self.tv_weight > 0:
|
650 |
+
tv_loss = (paired_sim[:, :, :, :, :, 1:] - paired_sim[:, :, :, :, :, :-1]).square().mean()
|
651 |
+
self.rolling_avg.add(f"loss/tv", tv_loss)
|
652 |
+
loss += tv_loss * self.tv_weight
|
653 |
+
|
654 |
+
self.log(f"cal/w", self.sim_cal.get_w())
|
655 |
+
if self.cal_balance_weight > 0.0:
|
656 |
+
cal_balance = (np.log(self.cal_init) - torch.log(self.sim_cal.get_w().clamp_min(.00000001))) \
|
657 |
+
.clamp_min(0).square().mean()
|
658 |
+
self.rolling_avg.add(f"loss/cal_balance", cal_balance)
|
659 |
+
loss += cal_balance * self.cal_balance_weight
|
660 |
+
|
661 |
+
if self.disentangle_weight > 0.0:
|
662 |
+
assert source is not None
|
663 |
+
assert self.sim_agg_heads % 2 == 0
|
664 |
+
|
665 |
+
dilation = self.sim_agg_heads // 2
|
666 |
+
sources_oh = F.one_hot(source, num_classes=2)
|
667 |
+
b, h = sources_oh.shape
|
668 |
+
sources_mask = 1 - torch.broadcast_to(sources_oh.unsqueeze(-1), (b, h, dilation)) \
|
669 |
+
.reshape(b, h * dilation).to(paired_sim)
|
670 |
+
disentangle_loss = torch.einsum("bkhwft,bk->bhwft", paired_sim, sources_mask).square().mean()
|
671 |
+
self.rolling_avg.add(f"loss/disentangle", disentangle_loss)
|
672 |
+
loss += disentangle_loss * self.disentangle_weight
|
673 |
+
|
674 |
+
if self.specialization_weight > 0.0 and self.sim_agg_heads > 1:
|
675 |
+
total_specialization_loss = 0.0
|
676 |
+
combos = list(combinations(range(self.sim_agg_heads), 2))
|
677 |
+
for i, j in combos:
|
678 |
+
specialization_loss_pair = (paired_sim[:, i].abs() * paired_sim[:, j].abs()).mean()
|
679 |
+
total_specialization_loss += specialization_loss_pair
|
680 |
+
avg_specialization_loss = total_specialization_loss / len(combos)
|
681 |
+
self.rolling_avg.add(f"loss/specialize", avg_specialization_loss)
|
682 |
+
loss += avg_specialization_loss * self.specialization_weight
|
683 |
+
|
684 |
+
if self.mixup_weight > 0.0:
|
685 |
+
b, _, h, w = image_mask.shape
|
686 |
+
neg_img_mask = torch.broadcast_to(
|
687 |
+
1 - image_mask.to(paired_sim).reshape(b, 1, h, w, 1, 1),
|
688 |
+
paired_sim.shape)
|
689 |
+
image_mixup_loss = (paired_sim * neg_img_mask).square().sum() / neg_img_mask.sum().clamp_min(0.1)
|
690 |
+
self.rolling_avg.add(f"loss/image_mixup", image_mixup_loss)
|
691 |
+
loss += image_mixup_loss * self.mixup_weight
|
692 |
+
|
693 |
+
sims = sims
|
694 |
+
loss += self.contrast_loss(sims)
|
695 |
+
self.rolling_avg.add(f"loss/total", loss)
|
696 |
+
|
697 |
+
return loss
|
698 |
+
|
699 |
+
def setup_hparams(self):
|
700 |
+
recalls = ['A_r1', 'A_r5', 'A_r10', 'I_r1', 'I_r5', 'I_r10']
|
701 |
+
|
702 |
+
if self.trainer.datamodule.use_extra_val_sets:
|
703 |
+
datasets = ["Places", "AudioSet"]
|
704 |
+
else:
|
705 |
+
datasets = ["Val"]
|
706 |
+
|
707 |
+
heads = ["total"]
|
708 |
+
|
709 |
+
metric_names = [
|
710 |
+
"hp/speech_basic_ap", "hp/speech_advanced_ap", "hp/sound_basic_ap",
|
711 |
+
"hp/speech_basic_iou", "hp/speech_advanced_iou", "hp/sound_basic_iou",
|
712 |
+
]
|
713 |
+
for dataset in datasets:
|
714 |
+
for head in heads:
|
715 |
+
for recall in recalls:
|
716 |
+
metric_names.append(f"hp/{dataset}/{head}/{recall}")
|
717 |
+
|
718 |
+
if self.sim_agg_heads == 2:
|
719 |
+
metric_names.extend(["hp/ap_dis", "hp/act_dis"])
|
720 |
+
|
721 |
+
if hasattr(self.trainer, "datamodule"):
|
722 |
+
all_hparams = {**self.hparams, **self.trainer.datamodule.hparams}
|
723 |
+
else:
|
724 |
+
all_hparams = self.hparams
|
725 |
+
|
726 |
+
starting_values = {n: torch.nan for n in metric_names}
|
727 |
+
self.logger.log_hyperparams(all_hparams, starting_values)
|
728 |
+
|
729 |
+
def on_train_start(self):
|
730 |
+
self.setup_hparams()
|
731 |
+
self.hparams_logged = True
|
732 |
+
|
733 |
+
def on_train_batch_start(self, batch, batch_idx):
|
734 |
+
remake_optimizers = False
|
735 |
+
|
736 |
+
if isinstance(self.image_aligner, ProgressiveGrowing):
|
737 |
+
should_remake = self.image_aligner.maybe_change_phase(self.global_step)
|
738 |
+
remake_optimizers = remake_optimizers or should_remake
|
739 |
+
if isinstance(self.audio_aligner, ProgressiveGrowing):
|
740 |
+
should_remake = self.audio_aligner.maybe_change_phase(self.global_step)
|
741 |
+
remake_optimizers = remake_optimizers or should_remake
|
742 |
+
|
743 |
+
if remake_optimizers:
|
744 |
+
raise NotImplementedError()
|
745 |
+
|
746 |
+
def _combine_preds(self, all_preds):
|
747 |
+
temp = {}
|
748 |
+
new_preds = {}
|
749 |
+
|
750 |
+
# Collect tensors for each key into lists
|
751 |
+
for d in all_preds:
|
752 |
+
for key, value in d.items():
|
753 |
+
if isinstance(value, torch.Tensor):
|
754 |
+
if key not in temp:
|
755 |
+
temp[key] = []
|
756 |
+
temp[key].append(value)
|
757 |
+
|
758 |
+
# Concatenate all tensors for each key using a single call to torch.cat
|
759 |
+
for key, tensor_list in temp.items():
|
760 |
+
new_preds[key] = torch.cat(tensor_list)
|
761 |
+
return new_preds
|
762 |
+
|
763 |
+
def training_step(self, batch, batch_idx):
|
764 |
+
assert batch[IMAGE_INPUT].shape[1] == 1
|
765 |
+
|
766 |
+
preds = self.forward(batch)
|
767 |
+
if DATA_SOURCE in batch:
|
768 |
+
preds[DATA_SOURCE] = batch[DATA_SOURCE]
|
769 |
+
|
770 |
+
if self.trainer.world_size > 1 and self.gather_tensors:
|
771 |
+
for k, v in preds.items():
|
772 |
+
new_v = v.contiguous()
|
773 |
+
preds[k] = torch.cat(GatherLayer.apply(new_v), dim=0)
|
774 |
+
|
775 |
+
if self.memory_buffer_size > 0:
|
776 |
+
new_preds = self._combine_preds(list(self.memory_buffer) + [preds])
|
777 |
+
else:
|
778 |
+
new_preds = preds
|
779 |
+
|
780 |
+
loss = self.loss(new_preds)
|
781 |
+
|
782 |
+
if self.memory_buffer_size > 0:
|
783 |
+
self.memory_buffer.append(self._recursive_detach(preds, gather=False))
|
784 |
+
|
785 |
+
if self.trainer.is_global_zero and self.global_step % 50 == 1:
|
786 |
+
            writer = self.logger.experiment
            self.rolling_avg.logall(lambda k, v: writer.add_scalar(k, v, global_step=self.global_step))

            if self.trainer.scaler is not None:
                self.log("loss_scale", self.trainer.scaler.get_scale())

            if self.global_step % 10000 == 0 and self.global_step > 0:
                print("RESETTING TFEVENT FILE")
                # Close the current TensorBoard event file and open a fresh one
                self.logger.experiment.close()
                self.logger.experiment._get_file_writer()

        return loss

    def on_validation_start(self) -> None:
        if not self.hparams_logged:
            self.setup_hparams()
            self.hparams_logged = True

    def _auto_gather(self, t):
        # Gather a tensor from all ranks onto rank zero; a no-op (beyond .cpu()) on a single device.
        if t.dtype == torch.bool:
            t = t.to(torch.float)

        if self.trainer.num_devices == 1:
            return t.cpu()

        t = torch.clone(t).contiguous()
        if self.trainer.is_global_zero:
            gather_list = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
            dist.gather(t, gather_list)
            return torch.cat(gather_list, dim=0).cpu()
        else:
            dist.gather(t)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):

        with torch.no_grad():
            preds = self.forward(batch)

            ret = {}
            for k in preds.keys():
                ret[k] = self._auto_gather(preds[k])

            batch_keys = [IMAGE_INPUT, "spec", "semseg", "num_pixels_per_class", 'total_length']
            for k in batch_keys:
                if k in batch:
                    ret[k] = self._auto_gather(batch[k])

            if "metadata" in batch:
                if isinstance(batch["metadata"]["id"], torch.Tensor):
                    ret["id"] = self._auto_gather(batch["metadata"]["id"])
                ret["index"] = self._auto_gather(batch["metadata"]["index"])

        return ret

    def _calc_recalls(self, sim):
        # Cross-modal recall@k: check whether the true pair appears among the top-k retrievals
        # along each axis of the similarity matrix.
        top_10_a = sim.topk(10, 0).indices == torch.arange(sim.shape[0]).unsqueeze(0)
        top_10_i = (sim.topk(10, 1).indices == torch.arange(sim.shape[0]).unsqueeze(1)).permute(1, 0)
        a_recall = lambda p: top_10_a[0:p].any(0).to(sim).mean()
        i_recall = lambda p: top_10_i[0:p].any(0).to(sim).mean()
        return {'A_r1': a_recall(1),
                'A_r5': a_recall(5),
                'A_r10': a_recall(10),
                'I_r1': i_recall(1),
                'I_r5': i_recall(5),
                'I_r10': i_recall(10)}

    def calc_recalls(self, preds, dataset):
        sim = self.sim_agg.forward_batched(
            preds=preds,
            agg_heads=False,
            batch_size=4,
        ).cpu()

        all_metrics = dict()
        for k, v in self._calc_recalls(sim.sum(-1)).items():
            all_metrics[f"hp/{dataset}/total/" + k] = v

        return all_metrics

    def retrieval_validation(self, outputs, dataset_name):
        if len(outputs) == 0:
            return

        if self.trainer.is_global_zero:
            results = flatten_preds(outputs)
            if not self.trainer.sanity_checking:
                print(results[IMAGE_FEATS].shape[0])
                # assert (results[IMAGE_FEATS].shape[0] == 1000)
            results[IMAGE_FEATS] = results[IMAGE_FEATS].cpu()
            results[AUDIO_FEATS] = results[AUDIO_FEATS].cuda()
            if self.sim_use_cls:
                results[AUDIO_CLS] = results[AUDIO_CLS].cuda()

            results[AUDIO_MASK] = results[AUDIO_MASK].cuda()

            recalls = self.calc_recalls(results, dataset_name)

            results[IMAGE_FEATS] = results[IMAGE_FEATS].cuda()

            writer = self.logger.experiment
            print("here")
            for name, v in recalls.items():
                writer.add_scalar(f"{name}", v, self.global_step + 1)

    def semseg_validation(self, speech_preds, sound_preds):

        if self.trainer.is_global_zero:
            from eval_utils import get_paired_heatmaps

            def prep_preds(preds, loader):
                results = flatten_preds(preds)
                metadata = loader.dataset.metadata
                ordered_metadata = metadata.iloc[results["index"].numpy(), :].copy()
                ordered_metadata["order"] = range(len(ordered_metadata))
                return results, ordered_metadata

            [_, _, speech_loader, sound_loader] = self.trainer.val_dataloaders
            speech_results, speech_metadata = prep_preds(speech_preds, speech_loader)
            sound_results, sound_metadata = prep_preds(sound_preds, sound_loader)

            self.sound_metrics, unique_sound_indices = get_paired_heatmaps(
                self, sound_results, sound_metadata["ade_class_id"], None)

            self.speech_metrics, unique_word_indices = get_paired_heatmaps(
                self, speech_results, speech_metadata["ade_class_id"], speech_metadata["timing"])

            writer = self.logger.experiment

            all_metrics = {
                **{"sound_" + k: v for k, v in self.sound_metrics.items()},
                **{"speech_" + k: v for k, v in self.speech_metrics.items()},
            }

            for k, v in all_metrics.items():
                writer.add_scalar(f"hp/{k}", torch.tensor(v).mean(), self.global_step + 1)

    def disentangle_validation(self, word_preds, sound_preds):

        if len(word_preds) == 0 or len(sound_preds) == 0:
            return

        if self.trainer.is_global_zero:
            word_preds = flatten_preds(word_preds)
            sound_preds = flatten_preds(sound_preds)

            word_scores = self.sim_agg.get_pairwise_sims(
                word_preds,
                raw=False,
                agg_sim=True,
                agg_heads=False,
            )

            sound_scores = self.sim_agg.get_pairwise_sims(
                sound_preds,
                raw=False,
                agg_sim=True,
                agg_heads=False,
            )

            # Min-max normalize the per-head scores across both datasets before scoring disentanglement.
            all_scores = torch.cat([word_scores, sound_scores], dim=0)
            all_scores -= all_scores.min(dim=0, keepdim=True).values
            all_scores /= all_scores.max(dim=0, keepdim=True).values.clamp_min(.0001)

            is_words = torch.cat([
                torch.ones(word_scores.shape[0]),
                torch.zeros(sound_scores.shape[0])], dim=0).to(torch.bool)

            assert all_scores.shape[1] == 2
            ap_matrix = torch.zeros(2, 2)
            act_matrix = torch.zeros(2, 2)

            for head in range(2):
                # writer.add_histogram(f"h{head}_all_scores", all_scores[:, head])
                for dataset_num in range(2):
                    if dataset_num == 0:
                        labels = is_words
                    else:
                        labels = ~is_words

                    ap_matrix[head, dataset_num] = binary_average_precision(
                        all_scores[:, head].cpu(), labels.to(torch.int64).cpu())

                    act_matrix[head, dataset_num] = 1 - (all_scores[:, head][labels]).mean()

            ap_dis = max(.5 * (ap_matrix[0, 0] + ap_matrix[1, 1]),
                         .5 * (ap_matrix[0, 1] + ap_matrix[1, 0]))

            act_dis = max(.5 * (act_matrix[0, 0] + act_matrix[1, 1]),
                          .5 * (act_matrix[0, 1] + act_matrix[1, 0]))

            print("AP", ap_matrix)
            print("AP dis", ap_dis)
            print("Act", act_matrix)
            print("Act dis", act_dis)

            writer = self.logger.experiment
            writer.add_scalar("hp/ap_dis", ap_dis, self.global_step + 1)
            writer.add_scalar("hp/act_dis", act_dis, self.global_step + 1)

    def validation_epoch_end(self, outputs) -> None:
        print("Val end")
        with torch.no_grad():
            if self.trainer.datamodule.use_extra_val_sets:
                if self.sim_agg_heads == 2:
                    self.disentangle_validation(outputs[0], outputs[1])
                self.retrieval_validation(outputs[0], "Places")
                self.retrieval_validation(outputs[1], "AudioSet")
                self.semseg_validation(outputs[2], outputs[3])

            else:
                print("HERE!")
                self.retrieval_validation(outputs, "Val")

            writer = self.logger.experiment
            writer.flush()

    def _recursive_detach(self, obj, gather=True):
        if isinstance(obj, torch.Tensor):
            if gather:
                return self._auto_gather(obj)
            else:
                return obj.detach()
        elif isinstance(obj, dict):
            return {k: self._recursive_detach(v, gather) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._recursive_detach(v, gather) for v in obj]
        else:
            return obj

    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        with torch.no_grad():
            predictions = {}
            for k, v in batch.items():
                predictions[k] = self._recursive_detach(v)
            for k, v in self.forward(batch).items():
                predictions[k] = self._auto_gather(v)

        return predictions

    def _configure_optimizers(self, full_train, lr):
        params = [
            *self.audio_aligner.parameters(),
            *self.image_aligner.parameters(),
            *self.sim_cal.parameters(),
            *self.sim_agg.parameters()
        ]

        if (self.finetune_image_model or self.image_lora) and full_train:
            params.extend(self.image_model.parameters())

        if (self.finetune_audio_model or self.audio_lora) and full_train:
            params.extend(self.audio_model.parameters())

        if self.learn_audio_cls:
            params.append(self.audio_cls)

        last_epoch = self.global_step - 1
        if self.optimizer == "adam":
            opt = torch.optim.Adam(params, lr=lr, eps=1e-7)
        elif self.optimizer == "nadam":
            opt = torch.optim.NAdam(params, lr=lr, eps=1e-7)
        else:
            raise ValueError(f"Unknown optimizer {self.optimizer}")

        if self.lr_schedule == "sgdr":
            scheduler = CosineAnnealingWarmRestarts(
                opt, self.lr_cycle_length, 2, eta_min=lr * 2e-2, last_epoch=last_epoch)
        else:
            scheduler = LambdaLR(opt, lr_lambda=lambda step: 1.0, last_epoch=last_epoch)

        if self.lr_warmup > 0:
            warmup = LambdaLR(
                opt,
                lr_lambda=lambda step: min(max(float(step), 0.0) / self.lr_warmup, 1.0),
                last_epoch=last_epoch,
            )
            scheduler = SequentialLR(
                opt,
                schedulers=[warmup, scheduler],
                milestones=[self.lr_warmup],
                last_epoch=last_epoch)

        scheduler = {"scheduler": scheduler, "interval": "step"}

        return [opt], [scheduler]

    def configure_optimizers(self):
        if self.full_train:
            return self._configure_optimizers(self.full_train, self.lr)
        else:
            return self._configure_optimizers(self.full_train, self.pretrain_lr)


@hydra.main(config_path="configs", config_name="av_align.yaml", version_base=None)
def my_app(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))
    seed_everything(cfg.seed, workers=True)

    exp_name = f"{cfg.resume_prefix}"

    if cfg.image_model_type == "dino8":
        patch_size = 8 * cfg.image_pool_width
    elif cfg.image_model_type == "cavmae":
        patch_size = 16 * cfg.image_pool_width
    elif cfg.image_model_type == "imagebind":
        patch_size = 16 * cfg.image_pool_width
    elif cfg.image_model_type == "clip":
        patch_size = 16 * cfg.image_pool_width
    elif cfg.image_model_type == "cavmae-mixed":
        patch_size = 16 * cfg.image_pool_width
    elif cfg.image_model_type == "dinov2":
        patch_size = 14 * cfg.image_pool_width
    else:
        raise ValueError(f"Unknown patch size for model {cfg.image_model_type}")

    datamodule = AVDataModule(
        dataset_name=cfg.dataset_name,
        load_size=cfg.load_size,
        image_aug=cfg.image_aug,
        audio_aug=cfg.audio_aug,
        extra_audio_masking=cfg.extra_audio_masking,
        audio_model_type=cfg.audio_model_type,
        pytorch_data_dir=cfg.pytorch_data_dir,
        use_cached_embs=cfg.use_cached_embs,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        audio_level=cfg.audio_level,
        neg_audio=cfg.neg_audio,
        use_original_val_set=not cfg.use_extra_val_sets,
        use_extra_val_sets=cfg.use_extra_val_sets,
        data_for_plotting=False,
        quad_mixup=cfg.quad_mixup,
        bg_mixup=cfg.bg_mixup,
        patch_mixup=cfg.patch_mixup,
        patch_size=patch_size
    )
    datamodule.maybe_unpack(remove_source=cfg.submitting_to_aml)

    aligner = create_model_from_cfg(LitAVAligner, cfg, {})

    if cfg.starting_weights is not None:
        loaded = torch.load(join(cfg.output_root, cfg.starting_weights), map_location='cpu')
        state = loaded["state_dict"]
        aligner.load_state_dict(state, strict=cfg.load_strict)
        del state
        del loaded

    if cfg.num_gpus > 1:
        # strategy = "ddp_sharded"  # _find_unused_parameters_true"
        strategy = "ddp"  # _find_unused_parameters_true"
    else:
        strategy = "auto"

    if cfg.dataset_name in {"places-audio", "mixed", "audio-set", "mixed-full"}:
        val_args = dict(check_val_every_n_epoch=2)
    elif cfg.dataset_name in {"dolphin"}:
        val_args = dict(check_val_every_n_epoch=5)
    else:
        val_args = dict(val_check_interval=10000)

    # val_args = dict(val_check_interval=1000)

    def maybe_get_ckpt(ckpt_dir):
        if cfg.auto_resume and os.path.exists(ckpt_dir):
            print(f"Attempting to resume from {ckpt_dir}")
            candidates = os.listdir(ckpt_dir)
            assert (len(candidates) == 1)
            return join(ckpt_dir, candidates[0])
        elif cfg.auto_resume:
            print(f"Could not find checkpoint at {ckpt_dir}")
            return None
        else:
            return None

    log_dir = join(cfg.output_root, "logs", cfg.grouping_name, exp_name)
    ckpt_dir = join(cfg.output_root, "checkpoints", cfg.grouping_name, exp_name)

    import gc
    torch.cuda.empty_cache()
    gc.collect()

    def run_exp(aligner, full_train):
        trainer_args = dict(
            accelerator='gpu',
            strategy=strategy,
            devices=cfg.num_gpus,
            num_sanity_val_steps=cfg.num_sanity_val_steps,
            log_every_n_steps=50,
            reload_dataloaders_every_n_epochs=10,
            precision="16",
            # profiler="simple",
            # precision="bf16",
            max_steps=cfg.max_steps,
            **val_args)

        aligner.set_full_train(full_train)
        if full_train:
            suffix = "train"
        else:
            suffix = "pretrain"
            trainer_args["max_steps"] = cfg.pretrain_steps

        print(f"Starting {suffix} phase")

        logger = TensorBoardLogger(join(log_dir, suffix), default_hp_metric=False)
        callbacks = [
            ModelCheckpoint(join(ckpt_dir, suffix), every_n_epochs=1),
            LearningRateMonitor(logging_interval='step'),
        ]
        Trainer(logger=logger,
                callbacks=callbacks,
                **trainer_args).fit(
            aligner,
            datamodule=datamodule,
            ckpt_path=maybe_get_ckpt(join(ckpt_dir, suffix)))

    train_chkpt = maybe_get_ckpt(join(ckpt_dir, "train"))

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Run the aligner-only pretraining phase first unless a full-train checkpoint already exists.
    if cfg.pretrain_steps > 0 and train_chkpt is None:
        print("---" * 10)
        print("Setup with full_train = False")
        run_exp(aligner, full_train=False)
        print("---" * 10)
    else:
        print("---" * 10)
        print("Setup with full_train = True")
        run_exp(aligner, full_train=True)
        print("---" * 10)


if __name__ == "__main__":
    my_app()
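The retrieval metrics above reduce to a standard cross-modal recall@k over a square audio/image similarity matrix. A minimal, self-contained sketch of that logic, using k=2 on a toy 4x4 matrix instead of the k=10 used in `_calc_recalls`, so it runs without a trained model:

import torch

# Toy similarity matrix: entry (i, j) scores audio clip i against image j,
# and index i is assumed to be the true pair in both modalities.
sim = torch.tensor([[0.9, 0.1, 0.2, 0.0],
                    [0.2, 0.8, 0.1, 0.3],
                    [0.1, 0.4, 0.7, 0.2],
                    [0.0, 0.3, 0.2, 0.6]])

k = 2  # `_calc_recalls` uses k = 10 on the real validation sets
# For every image (column): do the top-k audio rows contain the true index?
top_k_a = sim.topk(k, 0).indices == torch.arange(sim.shape[0]).unsqueeze(0)
# For every audio clip (row): do the top-k image columns contain the true index?
top_k_i = (sim.topk(k, 1).indices == torch.arange(sim.shape[0]).unsqueeze(1)).permute(1, 0)

a_recall = lambda p: top_k_a[0:p].any(0).float().mean()
i_recall = lambda p: top_k_i[0:p].any(0).float().mean()
print({'A_r1': a_recall(1).item(), 'A_r2': a_recall(2).item(),
       'I_r1': i_recall(1).item(), 'I_r2': i_recall(2).item()})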
DenseAV/gradio_app.py
ADDED
@@ -0,0 +1,196 @@
import csv
import os
import tempfile

import gradio as gr
import requests
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image
from torchaudio.functional import resample

from denseav.train import LitAVAligner
from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video
from denseav.shared import norm, crop_to_divisor, blur_dim
from os.path import join

if __name__ == "__main__":

    mode = "local"

    if mode == "local":
        sample_videos_dir = "samples"
    else:
        os.environ['TORCH_HOME'] = '/tmp/.cache'
        os.environ['HF_HOME'] = '/tmp/.cache'
        os.environ['HF_DATASETS_CACHE'] = '/tmp/.cache'
        os.environ['TRANSFORMERS_CACHE'] = '/tmp/.cache'
        os.environ['GRADIO_EXAMPLES_CACHE'] = '/tmp/gradio_cache'
        sample_videos_dir = "/tmp/samples"


    def download_video(url, save_path):
        response = requests.get(url)
        with open(save_path, 'wb') as file:
            file.write(response.content)


    base_url = "https://marhamilresearch4.blob.core.windows.net/denseav-public/samples/"
    sample_videos_urls = {
        "puppies.mp4": base_url + "puppies.mp4",
        "peppers.mp4": base_url + "peppers.mp4",
        "boat.mp4": base_url + "boat.mp4",
        "elephant2.mp4": base_url + "elephant2.mp4",
    }

    # Ensure the directory for sample videos exists
    os.makedirs(sample_videos_dir, exist_ok=True)

    # Download each sample video if it doesn't already exist
    for filename, url in sample_videos_urls.items():
        save_path = os.path.join(sample_videos_dir, filename)
        if not os.path.exists(save_path):
            print(f"Downloading {filename}...")
            download_video(url, save_path)
        else:
            print(f"{filename} already exists. Skipping download.")

    csv.field_size_limit(100000000)
    options = ['language', "sound-language", "sound"]
    load_size = 224
    plot_size = 224

    video_input = gr.Video(label="Choose a video to featurize", height=480)
    model_option = gr.Radio(options, value="language", label='Choose a model')

    video_output1 = gr.Video(label="Audio Video Attention", height=480)
    video_output2 = gr.Video(label="Multi-Head Audio Video Attention (only available for the sound-language model)",
                             height=480)
    video_output3 = gr.Video(label="Visual Features", height=480)

    models = {o: LitAVAligner.from_pretrained(f"mhamilton723/DenseAV-{o}") for o in options}


    def process_video(video, model_option):
        model = models[model_option].cuda()

        original_frames, audio, info = torchvision.io.read_video(video, end_pts=10, pts_unit='sec')
        sample_rate = 16000

        if info["audio_fps"] != sample_rate:
            audio = resample(audio, info["audio_fps"], sample_rate)
        audio = audio[0].unsqueeze(0)

        img_transform = T.Compose([
            T.Resize(load_size, Image.BILINEAR),
            lambda x: crop_to_divisor(x, 8),
            lambda x: x.to(torch.float32) / 255,
            norm])

        frames = torch.cat([img_transform(f.permute(2, 0, 1)).unsqueeze(0) for f in original_frames], axis=0)

        plotting_img_transform = T.Compose([
            T.Resize(plot_size, Image.BILINEAR),
            lambda x: crop_to_divisor(x, 8),
            lambda x: x.to(torch.float32) / 255])

        frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2))

        with torch.no_grad():
            audio_feats = model.forward_audio({"audio": audio.cuda()})
            audio_feats = {k: v.cpu() for k, v in audio_feats.items()}
            image_feats = model.forward_image({"frames": frames.unsqueeze(0).cuda()}, max_batch_size=2)
            image_feats = {k: v.cpu() for k, v in image_feats.items()}

            sim_by_head = model.sim_agg.get_pairwise_sims(
                {**image_feats, **audio_feats},
                raw=False,
                agg_sim=False,
                agg_heads=False
            ).mean(dim=-2).cpu()

            sim_by_head = blur_dim(sim_by_head, window=3, dim=-1)
            print(sim_by_head.shape)

        temp_video_path_1 = tempfile.mktemp(suffix='.mp4')

        plot_attention_video(
            sim_by_head,
            frames_to_plot,
            audio,
            info["video_fps"],
            sample_rate,
            temp_video_path_1)

        if model_option == "sound-language":
            temp_video_path_2 = tempfile.mktemp(suffix='.mp4')

            plot_2head_attention_video(
                sim_by_head,
                frames_to_plot,
                audio,
                info["video_fps"],
                sample_rate,
                temp_video_path_2)

        else:
            temp_video_path_2 = None

        temp_video_path_3 = tempfile.mktemp(suffix='.mp4')
        temp_video_path_4 = tempfile.mktemp(suffix='.mp4')

        plot_feature_video(
            image_feats["image_feats"].cpu(),
            audio_feats['audio_feats'].cpu(),
            frames_to_plot,
            audio,
            info["video_fps"],
            sample_rate,
            temp_video_path_3,
            temp_video_path_4,
        )
        # return temp_video_path_1, temp_video_path_2, temp_video_path_3, temp_video_path_4

        return temp_video_path_1, temp_video_path_2, temp_video_path_3


    with gr.Blocks() as demo:
        with gr.Column():
            gr.Markdown("## Visualizing Sound and Language with DenseAV")
            gr.Markdown(
                "This demo allows you to explore the inner attention maps of DenseAV's dense multi-head contrastive operator.")
            with gr.Row():
                with gr.Column(scale=1):
                    model_option.render()
                with gr.Column(scale=3):
                    video_input.render()
            with gr.Row():
                submit_button = gr.Button("Submit")
            with gr.Row():
                gr.Examples(
                    examples=[
                        [join(sample_videos_dir, "puppies.mp4"), "sound-language"],
                        [join(sample_videos_dir, "peppers.mp4"), "language"],
                        [join(sample_videos_dir, "elephant2.mp4"), "language"],
                        [join(sample_videos_dir, "boat.mp4"), "language"]
                    ],
                    inputs=[video_input, model_option]
                )
            with gr.Row():
                video_output1.render()
                video_output2.render()
                video_output3.render()

        submit_button.click(fn=process_video, inputs=[video_input, model_option],
                            outputs=[video_output1, video_output2, video_output3])

    if mode == "local":
        demo.launch(server_name="0.0.0.0", server_port=6006, debug=True)
    else:
        demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
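Gradio wiring aside, `process_video` is a plain function, so the demo can also be driven directly. A hypothetical sketch, assuming the script above has already run far enough to define `models`, `process_video`, and `sample_videos_dir` (for instance, from an interactive session inside `gradio_app.py`):

# Hypothetical direct call that bypasses the UI; the sample file and option
# name come from the lists defined in the script above.
attn_path, two_head_path, feat_path = process_video(
    join(sample_videos_dir, "puppies.mp4"), "sound-language")
print(attn_path, two_head_path, feat_path)  # paths to the rendered .mp4 visualizations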
DenseAV/hubconf.py
ADDED
@@ -0,0 +1,25 @@
# hubconf.py
from denseav.train import LitAVAligner

dependencies = ['torch', 'torchvision', 'PIL', 'denseav']  # List any dependencies here


def _load_base(model_name):
    model = LitAVAligner.load_from_checkpoint(
        f"https://marhamilresearch4.blob.core.windows.net/denseav-public/hub/{model_name}.ckpt",
        **{'loss_leak': 0.0, 'use_cached_embs': False},
        strict=True)
    model.set_full_train(True)
    return model


def sound_and_language():
    return _load_base("denseav_2head")


def language():
    return _load_base("denseav_language")


def sound():
    return _load_base("denseav_sound")
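With this `hubconf.py` at the repository root, the three entry points can be pulled through `torch.hub`. A minimal sketch, assuming the code is published at github.com/mhamilton723/DenseAV (as the `setup.py` URL below suggests) and the listed dependencies are installed:

import torch

# 'language', 'sound', and 'sound_and_language' mirror the entry points above.
model = torch.hub.load('mhamilton723/DenseAV', 'language')
model = model.cuda().eval()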
DenseAV/samples/puppies.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4bc5049010142b9a4364afea7da15d4e9736d95cfc9a365c2658c69ba409d56
size 7534432
DenseAV/setup.py
ADDED
@@ -0,0 +1,37 @@
from setuptools import setup, find_packages

setup(
    name='denseav',
    version='0.1.0',
    packages=find_packages(),
    install_requires=[
        'torch',
        'kornia',
        'omegaconf',
        'pytorch-lightning',
        'torchvision',
        'tqdm',
        'torchmetrics',
        'scikit-learn',
        'numpy',
        'matplotlib',
        'timm==0.4.12',
        'moviepy',
        'hydra-core',
        'peft==0.5.0',
        'av',
        'audioread'
    ],
    author='Mark Hamilton',
    author_email='[email protected]',
    description='Official code for the CVPR 2024 Paper: Separating the "Chirp" from the "Chat": Self-supervised Visual Grounding of Sound and Language',
    long_description=open('README.md').read(),
    long_description_content_type='text/markdown',
    url='https://github.com/mhamilton723/DenseAV',
    classifiers=[
        'Programming Language :: Python :: 3',
        'License :: OSI Approved :: MIT License',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.6'
)
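After an editable install of the package described above (for example, `pip install -e .` from the DenseAV directory), the modules referenced throughout this upload should import cleanly. A small, hypothetical smoke test:

# Assumes `pip install -e .` has been run inside the DenseAV checkout.
import denseav
from denseav.train import LitAVAligner
from denseav.plotting import plot_attention_video

print(denseav.__name__, LitAVAligner.__name__, plot_attention_video.__name__)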