LittleLirow committed on
Commit
2b6694e
1 Parent(s): 08ce9fc

Temporarily disable BGM module

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
Files changed (50)
  1. audioldm/__init__.py +0 -3
  2. audioldm/__pycache__/__init__.cpython-310.pyc +0 -0
  3. audioldm/__pycache__/ldm.cpython-310.pyc +0 -0
  4. audioldm/__pycache__/pipeline.cpython-310.pyc +0 -0
  5. audioldm/__pycache__/utils.cpython-310.pyc +0 -0
  6. audioldm/audio/__init__.py +0 -0
  7. audioldm/audio/audio_processing.py +0 -100
  8. audioldm/audio/stft.py +0 -180
  9. audioldm/audio/tools.py +0 -33
  10. audioldm/clap/__init__.py +0 -0
  11. audioldm/clap/__pycache__/__init__.cpython-310.pyc +0 -0
  12. audioldm/clap/__pycache__/encoders.cpython-310.pyc +0 -0
  13. audioldm/clap/encoders.py +0 -170
  14. audioldm/clap/open_clip/__init__.py +0 -25
  15. audioldm/clap/open_clip/__pycache__/__init__.cpython-310.pyc +0 -0
  16. audioldm/clap/open_clip/__pycache__/factory.cpython-310.pyc +0 -0
  17. audioldm/clap/open_clip/__pycache__/feature_fusion.cpython-310.pyc +0 -0
  18. audioldm/clap/open_clip/__pycache__/htsat.cpython-310.pyc +0 -0
  19. audioldm/clap/open_clip/__pycache__/loss.cpython-310.pyc +0 -0
  20. audioldm/clap/open_clip/__pycache__/model.cpython-310.pyc +0 -0
  21. audioldm/clap/open_clip/__pycache__/openai.cpython-310.pyc +0 -0
  22. audioldm/clap/open_clip/__pycache__/pann_model.cpython-310.pyc +0 -0
  23. audioldm/clap/open_clip/__pycache__/pretrained.cpython-310.pyc +0 -0
  24. audioldm/clap/open_clip/__pycache__/timm_model.cpython-310.pyc +0 -0
  25. audioldm/clap/open_clip/__pycache__/tokenizer.cpython-310.pyc +0 -0
  26. audioldm/clap/open_clip/__pycache__/transform.cpython-310.pyc +0 -0
  27. audioldm/clap/open_clip/__pycache__/utils.cpython-310.pyc +0 -0
  28. audioldm/clap/open_clip/bert.py +0 -40
  29. audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +0 -3
  30. audioldm/clap/open_clip/factory.py +0 -277
  31. audioldm/clap/open_clip/feature_fusion.py +0 -192
  32. audioldm/clap/open_clip/htsat.py +0 -1308
  33. audioldm/clap/open_clip/linear_probe.py +0 -66
  34. audioldm/clap/open_clip/loss.py +0 -398
  35. audioldm/clap/open_clip/model.py +0 -936
  36. audioldm/clap/open_clip/model_configs/HTSAT-base.json +0 -23
  37. audioldm/clap/open_clip/model_configs/HTSAT-large.json +0 -23
  38. audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +0 -23
  39. audioldm/clap/open_clip/model_configs/HTSAT-tiny.json +0 -23
  40. audioldm/clap/open_clip/model_configs/PANN-10.json +0 -23
  41. audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json +0 -23
  42. audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +0 -23
  43. audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +0 -23
  44. audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json +0 -23
  45. audioldm/clap/open_clip/model_configs/PANN-14.json +0 -23
  46. audioldm/clap/open_clip/model_configs/PANN-6.json +0 -23
  47. audioldm/clap/open_clip/model_configs/RN101-quickgelu.json +0 -22
  48. audioldm/clap/open_clip/model_configs/RN101.json +0 -21
  49. audioldm/clap/open_clip/model_configs/RN50-quickgelu.json +0 -22
  50. audioldm/clap/open_clip/model_configs/RN50.json +0 -21
audioldm/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .ldm import LatentDiffusion
- from .utils import seed_everything
- from .pipeline import *
 
audioldm/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (269 Bytes)
 
audioldm/__pycache__/ldm.cpython-310.pyc DELETED
Binary file (14.5 kB)
 
audioldm/__pycache__/pipeline.cpython-310.pyc DELETED
Binary file (2.17 kB)
 
audioldm/__pycache__/utils.cpython-310.pyc DELETED
Binary file (4.81 kB)
 
audioldm/audio/__init__.py DELETED
File without changes
audioldm/audio/audio_processing.py DELETED
@@ -1,100 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import librosa.util as librosa_util
4
- from scipy.signal import get_window
5
-
6
-
7
- def window_sumsquare(
8
- window,
9
- n_frames,
10
- hop_length,
11
- win_length,
12
- n_fft,
13
- dtype=np.float32,
14
- norm=None,
15
- ):
16
- """
17
- # from librosa 0.6
18
- Compute the sum-square envelope of a window function at a given hop length.
19
-
20
- This is used to estimate modulation effects induced by windowing
21
- observations in short-time fourier transforms.
22
-
23
- Parameters
24
- ----------
25
- window : string, tuple, number, callable, or list-like
26
- Window specification, as in `get_window`
27
-
28
- n_frames : int > 0
29
- The number of analysis frames
30
-
31
- hop_length : int > 0
32
- The number of samples to advance between frames
33
-
34
- win_length : [optional]
35
- The length of the window function. By default, this matches `n_fft`.
36
-
37
- n_fft : int > 0
38
- The length of each analysis frame.
39
-
40
- dtype : np.dtype
41
- The data type of the output
42
-
43
- Returns
44
- -------
45
- wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
46
- The sum-squared envelope of the window function
47
- """
48
- if win_length is None:
49
- win_length = n_fft
50
-
51
- n = n_fft + hop_length * (n_frames - 1)
52
- x = np.zeros(n, dtype=dtype)
53
-
54
- # Compute the squared window at the desired length
55
- win_sq = get_window(window, win_length, fftbins=True)
56
- win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
57
- win_sq = librosa_util.pad_center(win_sq, n_fft)
58
-
59
- # Fill the envelope
60
- for i in range(n_frames):
61
- sample = i * hop_length
62
- x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
63
- return x
64
-
65
-
66
- def griffin_lim(magnitudes, stft_fn, n_iters=30):
67
- """
68
- PARAMS
69
- ------
70
- magnitudes: spectrogram magnitudes
71
- stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
72
- """
73
-
74
- angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
75
- angles = angles.astype(np.float32)
76
- angles = torch.autograd.Variable(torch.from_numpy(angles))
77
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
78
-
79
- for i in range(n_iters):
80
- _, angles = stft_fn.transform(signal)
81
- signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
82
- return signal
83
-
84
-
85
- def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
86
- """
87
- PARAMS
88
- ------
89
- C: compression factor
90
- """
91
- return normalize_fun(torch.clamp(x, min=clip_val) * C)
92
-
93
-
94
- def dynamic_range_decompression(x, C=1):
95
- """
96
- PARAMS
97
- ------
98
- C: compression factor used to compress
99
- """
100
- return torch.exp(x) / C
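
Note: the helpers above are removed by this commit. For context, a minimal sketch of how the compression pair and the window envelope compose, runnable against a checkout prior to this commit; the tensor sizes and STFT parameters below are illustrative, not the project's configuration:

    import torch
    from audioldm.audio.audio_processing import (
        dynamic_range_compression,
        dynamic_range_decompression,
        window_sumsquare,
    )

    mag = torch.rand(1, 513, 100)                         # dummy magnitude spectrogram
    comp = dynamic_range_compression(mag)                 # log(clamp(x, min=1e-5) * C)
    back = dynamic_range_decompression(comp)              # exp(x) / C, inverts the step above
    assert torch.allclose(mag.clamp(min=1e-5), back, atol=1e-4)

    # sum-square window envelope used by STFT.inverse to undo windowing modulation
    env = window_sumsquare("hann", n_frames=100, hop_length=160,
                           win_length=1024, n_fft=1024)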
 
audioldm/audio/stft.py DELETED
@@ -1,180 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import numpy as np
4
- from scipy.signal import get_window
5
- from librosa.util import pad_center, tiny
6
- from librosa.filters import mel as librosa_mel_fn
7
-
8
- from audioldm.audio.audio_processing import (
9
- dynamic_range_compression,
10
- dynamic_range_decompression,
11
- window_sumsquare,
12
- )
13
-
14
-
15
- class STFT(torch.nn.Module):
16
- """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
17
-
18
- def __init__(self, filter_length, hop_length, win_length, window="hann"):
19
- super(STFT, self).__init__()
20
- self.filter_length = filter_length
21
- self.hop_length = hop_length
22
- self.win_length = win_length
23
- self.window = window
24
- self.forward_transform = None
25
- scale = self.filter_length / self.hop_length
26
- fourier_basis = np.fft.fft(np.eye(self.filter_length))
27
-
28
- cutoff = int((self.filter_length / 2 + 1))
29
- fourier_basis = np.vstack(
30
- [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
31
- )
32
-
33
- forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
34
- inverse_basis = torch.FloatTensor(
35
- np.linalg.pinv(scale * fourier_basis).T[:, None, :]
36
- )
37
-
38
- if window is not None:
39
- assert filter_length >= win_length
40
- # get window and zero center pad it to filter_length
41
- fft_window = get_window(window, win_length, fftbins=True)
42
- fft_window = pad_center(fft_window, filter_length)
43
- fft_window = torch.from_numpy(fft_window).float()
44
-
45
- # window the bases
46
- forward_basis *= fft_window
47
- inverse_basis *= fft_window
48
-
49
- self.register_buffer("forward_basis", forward_basis.float())
50
- self.register_buffer("inverse_basis", inverse_basis.float())
51
-
52
- def transform(self, input_data):
53
- num_batches = input_data.size(0)
54
- num_samples = input_data.size(1)
55
-
56
- self.num_samples = num_samples
57
-
58
- # similar to librosa, reflect-pad the input
59
- input_data = input_data.view(num_batches, 1, num_samples)
60
- input_data = F.pad(
61
- input_data.unsqueeze(1),
62
- (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
63
- mode="reflect",
64
- )
65
- input_data = input_data.squeeze(1)
66
-
67
- forward_transform = F.conv1d(
68
- input_data,
69
- torch.autograd.Variable(self.forward_basis, requires_grad=False),
70
- stride=self.hop_length,
71
- padding=0,
72
- ).cpu()
73
-
74
- cutoff = int((self.filter_length / 2) + 1)
75
- real_part = forward_transform[:, :cutoff, :]
76
- imag_part = forward_transform[:, cutoff:, :]
77
-
78
- magnitude = torch.sqrt(real_part**2 + imag_part**2)
79
- phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
80
-
81
- return magnitude, phase
82
-
83
- def inverse(self, magnitude, phase):
84
- recombine_magnitude_phase = torch.cat(
85
- [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
86
- )
87
-
88
- inverse_transform = F.conv_transpose1d(
89
- recombine_magnitude_phase,
90
- torch.autograd.Variable(self.inverse_basis, requires_grad=False),
91
- stride=self.hop_length,
92
- padding=0,
93
- )
94
-
95
- if self.window is not None:
96
- window_sum = window_sumsquare(
97
- self.window,
98
- magnitude.size(-1),
99
- hop_length=self.hop_length,
100
- win_length=self.win_length,
101
- n_fft=self.filter_length,
102
- dtype=np.float32,
103
- )
104
- # remove modulation effects
105
- approx_nonzero_indices = torch.from_numpy(
106
- np.where(window_sum > tiny(window_sum))[0]
107
- )
108
- window_sum = torch.autograd.Variable(
109
- torch.from_numpy(window_sum), requires_grad=False
110
- )
111
- window_sum = window_sum
112
- inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
113
- approx_nonzero_indices
114
- ]
115
-
116
- # scale by hop ratio
117
- inverse_transform *= float(self.filter_length) / self.hop_length
118
-
119
- inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
120
- inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
121
-
122
- return inverse_transform
123
-
124
- def forward(self, input_data):
125
- self.magnitude, self.phase = self.transform(input_data)
126
- reconstruction = self.inverse(self.magnitude, self.phase)
127
- return reconstruction
128
-
129
-
130
- class TacotronSTFT(torch.nn.Module):
131
- def __init__(
132
- self,
133
- filter_length,
134
- hop_length,
135
- win_length,
136
- n_mel_channels,
137
- sampling_rate,
138
- mel_fmin,
139
- mel_fmax,
140
- ):
141
- super(TacotronSTFT, self).__init__()
142
- self.n_mel_channels = n_mel_channels
143
- self.sampling_rate = sampling_rate
144
- self.stft_fn = STFT(filter_length, hop_length, win_length)
145
- mel_basis = librosa_mel_fn(
146
- sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
147
- )
148
- mel_basis = torch.from_numpy(mel_basis).float()
149
- self.register_buffer("mel_basis", mel_basis)
150
-
151
- def spectral_normalize(self, magnitudes, normalize_fun):
152
- output = dynamic_range_compression(magnitudes, normalize_fun)
153
- return output
154
-
155
- def spectral_de_normalize(self, magnitudes):
156
- output = dynamic_range_decompression(magnitudes)
157
- return output
158
-
159
- def mel_spectrogram(self, y, normalize_fun=torch.log):
160
- """Computes mel-spectrograms from a batch of waves
161
- PARAMS
162
- ------
163
- y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
164
-
165
- RETURNS
166
- -------
167
- mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
168
- """
169
- assert torch.min(y.data) >= -1, torch.min(y.data)
170
- assert torch.max(y.data) <= 1, torch.max(y.data)
171
-
172
- magnitudes, phases = self.stft_fn.transform(y)
173
- magnitudes = magnitudes.data
174
- mel_output = torch.matmul(self.mel_basis, magnitudes)
175
- mel_output = self.spectral_normalize(mel_output, normalize_fun)
176
- energy = torch.norm(magnitudes, dim=1)
177
-
178
- log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)
179
-
180
- return mel_output, log_magnitudes, energy
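
Note: TacotronSTFT is the entry point the rest of the audio pipeline appears to consume. A hedged sketch of its use against a pre-removal checkout; the constructor values are illustrative, not taken from the repository's configuration:

    import torch
    from audioldm.audio.stft import TacotronSTFT

    stft = TacotronSTFT(filter_length=1024, hop_length=160, win_length=1024,
                        n_mel_channels=64, sampling_rate=16000,
                        mel_fmin=0, mel_fmax=8000)
    wav = torch.rand(2, 16000) * 2 - 1                    # (B, T) waveforms in [-1, 1]
    mel, log_mag, energy = stft.mel_spectrogram(wav)      # (B, n_mel, frames), (B, n_freq, frames), (B, frames)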
 
audioldm/audio/tools.py DELETED
@@ -1,33 +0,0 @@
- import torch
- import numpy as np
-
-
- def get_mel_from_wav(audio, _stft):
-     audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
-     audio = torch.autograd.Variable(audio, requires_grad=False)
-     melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
-     melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
-     log_magnitudes_stft = (
-         torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
-     )
-     energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
-     return melspec, log_magnitudes_stft, energy
-
-
- # def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
- # mel = torch.stack([mel])
- # mel_decompress = _stft.spectral_de_normalize(mel)
- # mel_decompress = mel_decompress.transpose(1, 2).data.cpu()
- # spec_from_mel_scaling = 1000
- # spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)
- # spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)
- # spec_from_mel = spec_from_mel * spec_from_mel_scaling
-
- # audio = griffin_lim(
- # torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters
- # )
-
- # audio = audio.squeeze()
- # audio = audio.cpu().numpy()
- # audio_path = out_filename
- # write(audio_path, _stft.sampling_rate, audio)
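
Note: get_mel_from_wav glues a raw NumPy waveform to a TacotronSTFT instance from audioldm/audio/stft.py. A minimal sketch against a pre-removal checkout; the constructor values are illustrative:

    import numpy as np
    from audioldm.audio.stft import TacotronSTFT
    from audioldm.audio.tools import get_mel_from_wav

    stft = TacotronSTFT(1024, 160, 1024, 64, 16000, 0, 8000)    # illustrative parameters
    audio = np.random.uniform(-1, 1, 16000).astype(np.float32)  # 1-D waveform in [-1, 1]
    melspec, log_mag, energy = get_mel_from_wav(audio, stft)    # NumPy float32 arrays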
 
audioldm/clap/__init__.py DELETED
File without changes
audioldm/clap/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (161 Bytes)
 
audioldm/clap/__pycache__/encoders.cpython-310.pyc DELETED
Binary file (5.14 kB)
 
audioldm/clap/encoders.py DELETED
@@ -1,170 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from audioldm.clap.open_clip import create_model
4
- from audioldm.clap.training.data import get_audio_features
5
- import torchaudio
6
- from transformers import RobertaTokenizer
7
- import torch.nn.functional as F
8
-
9
-
10
- class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
11
- def __init__(
12
- self,
13
- pretrained_path="",
14
- key="class",
15
- sampling_rate=16000,
16
- embed_mode="audio",
17
- amodel = "HTSAT-tiny",
18
- unconditional_prob=0.1,
19
- random_mute=False,
20
- max_random_mute_portion=0.5,
21
- training_mode=True,
22
- ):
23
- super().__init__()
24
-
25
- self.key = key
26
- self.device = "cpu"
27
- self.precision = "fp32"
28
- self.amodel = amodel
29
- self.tmodel = "roberta" # the best text encoder in our training
30
- self.enable_fusion = False # False if you do not want to use the fusion model
31
- self.fusion_type = "aff_2d"
32
- self.pretrained = pretrained_path
33
- self.embed_mode = embed_mode
34
- self.embed_mode_orig = embed_mode
35
- self.sampling_rate = sampling_rate
36
- self.unconditional_prob = unconditional_prob
37
- self.random_mute = random_mute
38
- self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
39
- self.max_random_mute_portion = max_random_mute_portion
40
- self.training_mode = training_mode
41
- self.model, self.model_cfg = create_model(
42
- self.amodel,
43
- self.tmodel,
44
- self.pretrained,
45
- precision=self.precision,
46
- device=self.device,
47
- enable_fusion=self.enable_fusion,
48
- fusion_type=self.fusion_type,
49
- )
50
- for p in self.model.parameters():
51
- p.requires_grad = False
52
-
53
- self.model.eval()
54
-
55
- def get_unconditional_condition(self, batchsize):
56
- self.unconditional_token = self.model.get_text_embedding(
57
- self.tokenizer(["", ""])
58
- )[0:1]
59
- return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)
60
-
61
- def batch_to_list(self, batch):
62
- ret = []
63
- for i in range(batch.size(0)):
64
- ret.append(batch[i])
65
- return ret
66
-
67
- def make_decision(self, probability):
68
- if float(torch.rand(1)) < probability:
69
- return True
70
- else:
71
- return False
72
-
73
- def random_uniform(self, start, end):
74
- val = torch.rand(1).item()
75
- return start + (end - start) * val
76
-
77
- def _random_mute(self, waveform):
78
- # waveform: [bs, t-steps]
79
- t_steps = waveform.size(-1)
80
- for i in range(waveform.size(0)):
81
- mute_size = int(
82
- self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
83
- )
84
- mute_start = int(self.random_uniform(0, t_steps - mute_size))
85
- waveform[i, mute_start : mute_start + mute_size] = 0
86
- return waveform
87
-
88
- def cos_similarity(self, waveform, text):
89
- # waveform: [bs, t_steps]
90
- with torch.no_grad():
91
- self.embed_mode = "audio"
92
- audio_emb = self(waveform.cuda())
93
- self.embed_mode = "text"
94
- text_emb = self(text)
95
- similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
96
- return similarity.squeeze()
97
-
98
- def forward(self, batch, key=None):
99
- # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
100
- # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
101
- if self.model.training == True and not self.training_mode:
102
- print(
103
- "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
104
- )
105
- self.model, self.model_cfg = create_model(
106
- self.amodel,
107
- self.tmodel,
108
- self.pretrained,
109
- precision=self.precision,
110
- device="cuda",
111
- enable_fusion=self.enable_fusion,
112
- fusion_type=self.fusion_type,
113
- )
114
- for p in self.model.parameters():
115
- p.requires_grad = False
116
- self.model.eval()
117
-
118
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
119
- if self.embed_mode == "audio":
120
- with torch.no_grad():
121
- audio_dict_list = []
122
- assert (
123
- self.sampling_rate == 16000
124
- ), "We only support 16000 sampling rate"
125
- if self.random_mute:
126
- batch = self._random_mute(batch)
127
- # batch: [bs, 1, t-samples]
128
- batch = torchaudio.functional.resample(
129
- batch, orig_freq=self.sampling_rate, new_freq=48000
130
- )
131
- for waveform in self.batch_to_list(batch):
132
- audio_dict = {}
133
- audio_dict = get_audio_features(
134
- audio_dict,
135
- waveform,
136
- 480000,
137
- data_truncating="fusion",
138
- data_filling="repeatpad",
139
- audio_cfg=self.model_cfg["audio_cfg"],
140
- )
141
- audio_dict_list.append(audio_dict)
142
- # [bs, 512]
143
- embed = self.model.get_audio_embedding(audio_dict_list)
144
- elif self.embed_mode == "text":
145
- with torch.no_grad():
146
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
147
- text_data = self.tokenizer(batch)
148
- embed = self.model.get_text_embedding(text_data)
149
-
150
- embed = embed.unsqueeze(1)
151
- self.unconditional_token = self.model.get_text_embedding(
152
- self.tokenizer(["", ""])
153
- )[0:1]
154
-
155
- for i in range(embed.size(0)):
156
- if self.make_decision(self.unconditional_prob):
157
- embed[i] = self.unconditional_token
158
-
159
- # [bs, 1, 512]
160
- return embed.detach()
161
-
162
- def tokenizer(self, text):
163
- result = self.tokenize(
164
- text,
165
- padding="max_length",
166
- truncation=True,
167
- max_length=512,
168
- return_tensors="pt",
169
- )
170
- return {k: v.squeeze(0) for k, v in result.items()}
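
Note: CLAPAudioEmbeddingClassifierFreev2 switches between audio and text conditioning via embed_mode. A hedged example against a pre-removal checkout; the checkpoint path is hypothetical and real CLAP weights would be needed for this to run:

    from audioldm.clap.encoders import CLAPAudioEmbeddingClassifierFreev2

    clap = CLAPAudioEmbeddingClassifierFreev2(
        pretrained_path="checkpoints/clap.pt",   # hypothetical path to CLAP weights
        sampling_rate=16000,                     # audio mode asserts 16 kHz input
        embed_mode="text",
        unconditional_prob=0.0,                  # never swap in the unconditional token
    )
    emb = clap(["a gentle piano melody"])        # -> detached tensor of shape [1, 1, 512]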
 
audioldm/clap/open_clip/__init__.py DELETED
@@ -1,25 +0,0 @@
- from .factory import (
-     list_models,
-     create_model,
-     create_model_and_transforms,
-     add_model_config,
- )
- from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
- from .model import (
-     CLAP,
-     CLAPTextCfg,
-     CLAPVisionCfg,
-     CLAPAudioCfp,
-     convert_weights_to_fp16,
-     trace_model,
- )
- from .openai import load_openai_model, list_openai_models
- from .pretrained import (
-     list_pretrained,
-     list_pretrained_tag_models,
-     list_pretrained_model_tags,
-     get_pretrained_url,
-     download_pretrained,
- )
- from .tokenizer import SimpleTokenizer, tokenize
- from .transform import image_transform
 
audioldm/clap/open_clip/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (967 Bytes)
 
audioldm/clap/open_clip/__pycache__/factory.cpython-310.pyc DELETED
Binary file (6.65 kB)
 
audioldm/clap/open_clip/__pycache__/feature_fusion.cpython-310.pyc DELETED
Binary file (4.12 kB)
 
audioldm/clap/open_clip/__pycache__/htsat.cpython-310.pyc DELETED
Binary file (30.8 kB)
 
audioldm/clap/open_clip/__pycache__/loss.cpython-310.pyc DELETED
Binary file (7.98 kB)
 
audioldm/clap/open_clip/__pycache__/model.cpython-310.pyc DELETED
Binary file (24.2 kB)
 
audioldm/clap/open_clip/__pycache__/openai.cpython-310.pyc DELETED
Binary file (4.53 kB)
 
audioldm/clap/open_clip/__pycache__/pann_model.cpython-310.pyc DELETED
Binary file (13.1 kB)
 
audioldm/clap/open_clip/__pycache__/pretrained.cpython-310.pyc DELETED
Binary file (5.04 kB)
 
audioldm/clap/open_clip/__pycache__/timm_model.cpython-310.pyc DELETED
Binary file (3.44 kB)
 
audioldm/clap/open_clip/__pycache__/tokenizer.cpython-310.pyc DELETED
Binary file (7.36 kB)
 
audioldm/clap/open_clip/__pycache__/transform.cpython-310.pyc DELETED
Binary file (985 Bytes)
 
audioldm/clap/open_clip/__pycache__/utils.cpython-310.pyc DELETED
Binary file (9.88 kB)
 
audioldm/clap/open_clip/bert.py DELETED
@@ -1,40 +0,0 @@
- from transformers import BertTokenizer, BertModel
-
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
- model = BertModel.from_pretrained("bert-base-uncased")
- text = "Replace me by any text you'd like."
-
-
- def bert_embeddings(text):
-     # text = "Replace me by any text you'd like."
-     encoded_input = tokenizer(text, return_tensors="pt")
-     output = model(**encoded_input)
-     return output
-
-
- from transformers import RobertaTokenizer, RobertaModel
-
- tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
- model = RobertaModel.from_pretrained("roberta-base")
- text = "Replace me by any text you'd like."
-
-
- def Roberta_embeddings(text):
-     # text = "Replace me by any text you'd like."
-     encoded_input = tokenizer(text, return_tensors="pt")
-     output = model(**encoded_input)
-     return output
-
-
- from transformers import BartTokenizer, BartModel
-
- tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
- model = BartModel.from_pretrained("facebook/bart-base")
- text = "Replace me by any text you'd like."
-
-
- def bart_embeddings(text):
-     # text = "Replace me by any text you'd like."
-     encoded_input = tokenizer(text, return_tensors="pt")
-     output = model(**encoded_input)
-     return output
 
audioldm/clap/open_clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
- size 1356917
 
audioldm/clap/open_clip/factory.py DELETED
@@ -1,277 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import pathlib
5
- import re
6
- from copy import deepcopy
7
- from pathlib import Path
8
-
9
- import torch
10
-
11
- from .model import CLAP, convert_weights_to_fp16
12
- from .openai import load_openai_model
13
- from .pretrained import get_pretrained_url, download_pretrained
14
- from .transform import image_transform
15
-
16
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
17
- _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
18
-
19
-
20
- def _natural_key(string_):
21
- return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
22
-
23
-
24
- def _rescan_model_configs():
25
- global _MODEL_CONFIGS
26
-
27
- config_ext = (".json",)
28
- config_files = []
29
- for config_path in _MODEL_CONFIG_PATHS:
30
- if config_path.is_file() and config_path.suffix in config_ext:
31
- config_files.append(config_path)
32
- elif config_path.is_dir():
33
- for ext in config_ext:
34
- config_files.extend(config_path.glob(f"*{ext}"))
35
-
36
- for cf in config_files:
37
- if os.path.basename(cf)[0] == ".":
38
- continue # Ignore hidden files
39
-
40
- with open(cf, "r") as f:
41
- model_cfg = json.load(f)
42
- if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
43
- _MODEL_CONFIGS[cf.stem] = model_cfg
44
-
45
- _MODEL_CONFIGS = {
46
- k: v
47
- for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
48
- }
49
-
50
-
51
- _rescan_model_configs() # initial populate of model config registry
52
-
53
-
54
- def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
55
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
56
- if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
57
- state_dict = checkpoint["state_dict"]
58
- else:
59
- state_dict = checkpoint
60
- if skip_params:
61
- if next(iter(state_dict.items()))[0].startswith("module"):
62
- state_dict = {k[7:]: v for k, v in state_dict.items()}
63
- # for k in state_dict:
64
- # if k.startswith('transformer'):
65
- # v = state_dict.pop(k)
66
- # state_dict['text_branch.' + k[12:]] = v
67
- return state_dict
68
-
69
-
70
- def create_model(
71
- amodel_name: str,
72
- tmodel_name: str,
73
- pretrained: str = "",
74
- precision: str = "fp32",
75
- device: torch.device = torch.device("cpu"),
76
- jit: bool = False,
77
- force_quick_gelu: bool = False,
78
- openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
79
- skip_params=True,
80
- pretrained_audio: str = "",
81
- pretrained_text: str = "",
82
- enable_fusion: bool = False,
83
- fusion_type: str = "None"
84
- # pretrained_image: bool = False,
85
- ):
86
- amodel_name = amodel_name.replace(
87
- "/", "-"
88
- ) # for callers using old naming with / in ViT names
89
- pretrained_orig = pretrained
90
- pretrained = pretrained.lower()
91
- if pretrained == "openai":
92
- if amodel_name in _MODEL_CONFIGS:
93
- logging.info(f"Loading {amodel_name} model config.")
94
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
95
- else:
96
- logging.error(
97
- f"Model config for {amodel_name} not found; available models {list_models()}."
98
- )
99
- raise RuntimeError(f"Model config for {amodel_name} not found.")
100
-
101
- logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
102
- # Hard Code in model name
103
- model_cfg["text_cfg"]["model_type"] = tmodel_name
104
- model = load_openai_model(
105
- "ViT-B-16",
106
- model_cfg,
107
- device=device,
108
- jit=jit,
109
- cache_dir=openai_model_cache_dir,
110
- enable_fusion=enable_fusion,
111
- fusion_type=fusion_type,
112
- )
113
- # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
114
- if precision == "amp" or precision == "fp32":
115
- model = model.float()
116
- else:
117
- if amodel_name in _MODEL_CONFIGS:
118
- logging.info(f"Loading {amodel_name} model config.")
119
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
120
- else:
121
- logging.error(
122
- f"Model config for {amodel_name} not found; available models {list_models()}."
123
- )
124
- raise RuntimeError(f"Model config for {amodel_name} not found.")
125
-
126
- if force_quick_gelu:
127
- # override for use of QuickGELU on non-OpenAI transformer models
128
- model_cfg["quick_gelu"] = True
129
-
130
- # if pretrained_image:
131
- # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
132
- # # pretrained weight loading for timm models set via vision_cfg
133
- # model_cfg['vision_cfg']['timm_model_pretrained'] = True
134
- # else:
135
- # assert False, 'pretrained image towers currently only supported for timm models'
136
- model_cfg["text_cfg"]["model_type"] = tmodel_name
137
- model_cfg["enable_fusion"] = enable_fusion
138
- model_cfg["fusion_type"] = fusion_type
139
- model = CLAP(**model_cfg)
140
-
141
- if pretrained:
142
- checkpoint_path = ""
143
- url = get_pretrained_url(amodel_name, pretrained)
144
- if url:
145
- checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
146
- elif os.path.exists(pretrained_orig):
147
- checkpoint_path = pretrained_orig
148
- if checkpoint_path:
149
- logging.info(
150
- f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
151
- )
152
- ckpt = load_state_dict(checkpoint_path, skip_params=True)
153
- model.load_state_dict(ckpt)
154
- param_names = [n for n, p in model.named_parameters()]
155
- # for n in param_names:
156
- # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
157
- else:
158
- logging.warning(
159
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
160
- )
161
- raise RuntimeError(
162
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
163
- )
164
-
165
- if pretrained_audio:
166
- if amodel_name.startswith("PANN"):
167
- if "Cnn14_mAP" in pretrained_audio: # official checkpoint
168
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
169
- audio_ckpt = audio_ckpt["model"]
170
- keys = list(audio_ckpt.keys())
171
- for key in keys:
172
- if (
173
- "spectrogram_extractor" not in key
174
- and "logmel_extractor" not in key
175
- ):
176
- v = audio_ckpt.pop(key)
177
- audio_ckpt["audio_branch." + key] = v
178
- elif os.path.basename(pretrained_audio).startswith(
179
- "PANN"
180
- ): # checkpoint trained via HTSAT codebase
181
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
182
- audio_ckpt = audio_ckpt["state_dict"]
183
- keys = list(audio_ckpt.keys())
184
- for key in keys:
185
- if key.startswith("sed_model"):
186
- v = audio_ckpt.pop(key)
187
- audio_ckpt["audio_branch." + key[10:]] = v
188
- elif os.path.basename(pretrained_audio).startswith(
189
- "finetuned"
190
- ): # checkpoint trained via linear probe codebase
191
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
192
- else:
193
- raise ValueError("Unknown audio checkpoint")
194
- elif amodel_name.startswith("HTSAT"):
195
- if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
196
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
197
- audio_ckpt = audio_ckpt["state_dict"]
198
- keys = list(audio_ckpt.keys())
199
- for key in keys:
200
- if key.startswith("sed_model") and (
201
- "spectrogram_extractor" not in key
202
- and "logmel_extractor" not in key
203
- ):
204
- v = audio_ckpt.pop(key)
205
- audio_ckpt["audio_branch." + key[10:]] = v
206
- elif os.path.basename(pretrained_audio).startswith(
207
- "HTSAT"
208
- ): # checkpoint trained via HTSAT codebase
209
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
210
- audio_ckpt = audio_ckpt["state_dict"]
211
- keys = list(audio_ckpt.keys())
212
- for key in keys:
213
- if key.startswith("sed_model"):
214
- v = audio_ckpt.pop(key)
215
- audio_ckpt["audio_branch." + key[10:]] = v
216
- elif os.path.basename(pretrained_audio).startswith(
217
- "finetuned"
218
- ): # checkpoint trained via linear probe codebase
219
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
220
- else:
221
- raise ValueError("Unknown audio checkpoint")
222
- else:
223
- raise f"this audio encoder pretrained checkpoint is not support"
224
-
225
- model.load_state_dict(audio_ckpt, strict=False)
226
- logging.info(
227
- f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
228
- )
229
- param_names = [n for n, p in model.named_parameters()]
230
- for n in param_names:
231
- print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
232
-
233
- model.to(device=device)
234
- if precision == "fp16":
235
- assert device.type != "cpu"
236
- convert_weights_to_fp16(model)
237
-
238
- if jit:
239
- model = torch.jit.script(model)
240
-
241
- return model, model_cfg
242
-
243
-
244
- def create_model_and_transforms(
245
- model_name: str,
246
- pretrained: str = "",
247
- precision: str = "fp32",
248
- device: torch.device = torch.device("cpu"),
249
- jit: bool = False,
250
- force_quick_gelu: bool = False,
251
- # pretrained_image: bool = False,
252
- ):
253
- model = create_model(
254
- model_name,
255
- pretrained,
256
- precision,
257
- device,
258
- jit,
259
- force_quick_gelu=force_quick_gelu,
260
- # pretrained_image=pretrained_image
261
- )
262
- preprocess_train = image_transform(model.visual.image_size, is_train=True)
263
- preprocess_val = image_transform(model.visual.image_size, is_train=False)
264
- return model, preprocess_train, preprocess_val
265
-
266
-
267
- def list_models():
268
- """enumerate available model architectures based on config files"""
269
- return list(_MODEL_CONFIGS.keys())
270
-
271
-
272
- def add_model_config(path):
273
- """add model config path or file and update registry"""
274
- if not isinstance(path, Path):
275
- path = Path(path)
276
- _MODEL_CONFIG_PATHS.append(path)
277
- _rescan_model_configs()
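
Note: create_model resolves the audio-tower architecture by name from the JSON files under model_configs/ and optionally loads pretrained weights. A hedged sketch against a pre-removal checkout; the checkpoint path is hypothetical:

    import torch
    from audioldm.clap.open_clip import create_model, list_models

    print(list_models())                         # config stems, e.g. "HTSAT-tiny", "PANN-14"
    model, model_cfg = create_model(
        "HTSAT-tiny",                            # audio model; must match a config name
        "roberta",                               # text model type written into text_cfg
        "checkpoints/clap.pt",                   # hypothetical local checkpoint path
        precision="fp32",
        device=torch.device("cpu"),
        enable_fusion=False,
    )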
 
audioldm/clap/open_clip/feature_fusion.py DELETED
@@ -1,192 +0,0 @@
1
- """
2
- Feature Fusion for Varible-Length Data Processing
3
- AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
- According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
-
10
-
11
- class DAF(nn.Module):
12
- """
13
- 直接相加 DirectAddFuse
14
- """
15
-
16
- def __init__(self):
17
- super(DAF, self).__init__()
18
-
19
- def forward(self, x, residual):
20
- return x + residual
21
-
22
-
23
- class iAFF(nn.Module):
24
- """
25
- 多特征融合 iAFF
26
- """
27
-
28
- def __init__(self, channels=64, r=4, type="2D"):
29
- super(iAFF, self).__init__()
30
- inter_channels = int(channels // r)
31
-
32
- if type == "1D":
33
- # 本地注意力
34
- self.local_att = nn.Sequential(
35
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
36
- nn.BatchNorm1d(inter_channels),
37
- nn.ReLU(inplace=True),
38
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
39
- nn.BatchNorm1d(channels),
40
- )
41
-
42
- # 全局注意力
43
- self.global_att = nn.Sequential(
44
- nn.AdaptiveAvgPool1d(1),
45
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
46
- nn.BatchNorm1d(inter_channels),
47
- nn.ReLU(inplace=True),
48
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
49
- nn.BatchNorm1d(channels),
50
- )
51
-
52
- # 第二次本地注意力
53
- self.local_att2 = nn.Sequential(
54
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
55
- nn.BatchNorm1d(inter_channels),
56
- nn.ReLU(inplace=True),
57
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
58
- nn.BatchNorm1d(channels),
59
- )
60
- # 第二次全局注意力
61
- self.global_att2 = nn.Sequential(
62
- nn.AdaptiveAvgPool1d(1),
63
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
64
- nn.BatchNorm1d(inter_channels),
65
- nn.ReLU(inplace=True),
66
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
67
- nn.BatchNorm1d(channels),
68
- )
69
- elif type == "2D":
70
- # 本地注意力
71
- self.local_att = nn.Sequential(
72
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
73
- nn.BatchNorm2d(inter_channels),
74
- nn.ReLU(inplace=True),
75
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
76
- nn.BatchNorm2d(channels),
77
- )
78
-
79
- # 全局注意力
80
- self.global_att = nn.Sequential(
81
- nn.AdaptiveAvgPool2d(1),
82
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
83
- nn.BatchNorm2d(inter_channels),
84
- nn.ReLU(inplace=True),
85
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
86
- nn.BatchNorm2d(channels),
87
- )
88
-
89
- # 第二次本地注意力
90
- self.local_att2 = nn.Sequential(
91
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
92
- nn.BatchNorm2d(inter_channels),
93
- nn.ReLU(inplace=True),
94
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
95
- nn.BatchNorm2d(channels),
96
- )
97
- # 第二次全局注意力
98
- self.global_att2 = nn.Sequential(
99
- nn.AdaptiveAvgPool2d(1),
100
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
101
- nn.BatchNorm2d(inter_channels),
102
- nn.ReLU(inplace=True),
103
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
104
- nn.BatchNorm2d(channels),
105
- )
106
- else:
107
- raise f"the type is not supported"
108
-
109
- self.sigmoid = nn.Sigmoid()
110
-
111
- def forward(self, x, residual):
112
- flag = False
113
- xa = x + residual
114
- if xa.size(0) == 1:
115
- xa = torch.cat([xa, xa], dim=0)
116
- flag = True
117
- xl = self.local_att(xa)
118
- xg = self.global_att(xa)
119
- xlg = xl + xg
120
- wei = self.sigmoid(xlg)
121
- xi = x * wei + residual * (1 - wei)
122
-
123
- xl2 = self.local_att2(xi)
124
- xg2 = self.global_att(xi)
125
- xlg2 = xl2 + xg2
126
- wei2 = self.sigmoid(xlg2)
127
- xo = x * wei2 + residual * (1 - wei2)
128
- if flag:
129
- xo = xo[0].unsqueeze(0)
130
- return xo
131
-
132
-
133
- class AFF(nn.Module):
134
- """
135
- 多特征融合 AFF
136
- """
137
-
138
- def __init__(self, channels=64, r=4, type="2D"):
139
- super(AFF, self).__init__()
140
- inter_channels = int(channels // r)
141
-
142
- if type == "1D":
143
- self.local_att = nn.Sequential(
144
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
145
- nn.BatchNorm1d(inter_channels),
146
- nn.ReLU(inplace=True),
147
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
148
- nn.BatchNorm1d(channels),
149
- )
150
- self.global_att = nn.Sequential(
151
- nn.AdaptiveAvgPool1d(1),
152
- nn.Conv1d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
153
- nn.BatchNorm1d(inter_channels),
154
- nn.ReLU(inplace=True),
155
- nn.Conv1d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
156
- nn.BatchNorm1d(channels),
157
- )
158
- elif type == "2D":
159
- self.local_att = nn.Sequential(
160
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
161
- nn.BatchNorm2d(inter_channels),
162
- nn.ReLU(inplace=True),
163
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
164
- nn.BatchNorm2d(channels),
165
- )
166
- self.global_att = nn.Sequential(
167
- nn.AdaptiveAvgPool2d(1),
168
- nn.Conv2d(channels, inter_channels, kernel_size=1, stride=1, padding=0),
169
- nn.BatchNorm2d(inter_channels),
170
- nn.ReLU(inplace=True),
171
- nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
172
- nn.BatchNorm2d(channels),
173
- )
174
- else:
175
- raise f"the type is not supported."
176
-
177
- self.sigmoid = nn.Sigmoid()
178
-
179
- def forward(self, x, residual):
180
- flag = False
181
- xa = x + residual
182
- if xa.size(0) == 1:
183
- xa = torch.cat([xa, xa], dim=0)
184
- flag = True
185
- xl = self.local_att(xa)
186
- xg = self.global_att(xa)
187
- xlg = xl + xg
188
- wei = self.sigmoid(xlg)
189
- xo = 2 * x * wei + 2 * residual * (1 - wei)
190
- if flag:
191
- xo = xo[0].unsqueeze(0)
192
- return xo
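
Note: DAF is a plain element-wise sum, while AFF/iAFF weight the two inputs with a local plus global channel-attention gate. A minimal sketch against a pre-removal checkout; shapes are illustrative:

    import torch
    from audioldm.clap.open_clip.feature_fusion import AFF

    fuse = AFF(channels=64, r=4, type="2D")      # 2-D variant built from Conv2d/BatchNorm2d blocks
    x = torch.randn(4, 64, 32, 32)
    residual = torch.randn(4, 64, 32, 32)
    y = fuse(x, residual)                        # fused output, same shape [4, 64, 32, 32]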
 
audioldm/clap/open_clip/htsat.py DELETED
@@ -1,1308 +0,0 @@
1
- # Ke Chen
2
3
- # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
- # Some layers designed on the model
5
- # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
- # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from itertools import repeat
12
- import collections.abc
13
- import math
14
- import warnings
15
-
16
- from torch.nn.init import _calculate_fan_in_and_fan_out
17
- import torch.utils.checkpoint as checkpoint
18
-
19
- import random
20
-
21
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
22
- from torchlibrosa.augmentation import SpecAugmentation
23
-
24
- from itertools import repeat
25
- from .utils import do_mixup, interpolate
26
-
27
- from .feature_fusion import iAFF, AFF, DAF
28
-
29
- # from PyTorch internals
30
- def _ntuple(n):
31
- def parse(x):
32
- if isinstance(x, collections.abc.Iterable):
33
- return x
34
- return tuple(repeat(x, n))
35
-
36
- return parse
37
-
38
-
39
- to_1tuple = _ntuple(1)
40
- to_2tuple = _ntuple(2)
41
- to_3tuple = _ntuple(3)
42
- to_4tuple = _ntuple(4)
43
- to_ntuple = _ntuple
44
-
45
-
46
- def drop_path(x, drop_prob: float = 0.0, training: bool = False):
47
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
48
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
49
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
50
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
51
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
52
- 'survival rate' as the argument.
53
- """
54
- if drop_prob == 0.0 or not training:
55
- return x
56
- keep_prob = 1 - drop_prob
57
- shape = (x.shape[0],) + (1,) * (
58
- x.ndim - 1
59
- ) # work with diff dim tensors, not just 2D ConvNets
60
- random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
61
- random_tensor.floor_() # binarize
62
- output = x.div(keep_prob) * random_tensor
63
- return output
64
-
65
-
66
- class DropPath(nn.Module):
67
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
68
-
69
- def __init__(self, drop_prob=None):
70
- super(DropPath, self).__init__()
71
- self.drop_prob = drop_prob
72
-
73
- def forward(self, x):
74
- return drop_path(x, self.drop_prob, self.training)
75
-
76
-
77
- class PatchEmbed(nn.Module):
78
- """2D Image to Patch Embedding"""
79
-
80
- def __init__(
81
- self,
82
- img_size=224,
83
- patch_size=16,
84
- in_chans=3,
85
- embed_dim=768,
86
- norm_layer=None,
87
- flatten=True,
88
- patch_stride=16,
89
- enable_fusion=False,
90
- fusion_type="None",
91
- ):
92
- super().__init__()
93
- img_size = to_2tuple(img_size)
94
- patch_size = to_2tuple(patch_size)
95
- patch_stride = to_2tuple(patch_stride)
96
- self.img_size = img_size
97
- self.patch_size = patch_size
98
- self.patch_stride = patch_stride
99
- self.grid_size = (
100
- img_size[0] // patch_stride[0],
101
- img_size[1] // patch_stride[1],
102
- )
103
- self.num_patches = self.grid_size[0] * self.grid_size[1]
104
- self.flatten = flatten
105
- self.in_chans = in_chans
106
- self.embed_dim = embed_dim
107
-
108
- self.enable_fusion = enable_fusion
109
- self.fusion_type = fusion_type
110
-
111
- padding = (
112
- (patch_size[0] - patch_stride[0]) // 2,
113
- (patch_size[1] - patch_stride[1]) // 2,
114
- )
115
-
116
- if (self.enable_fusion) and (self.fusion_type == "channel_map"):
117
- self.proj = nn.Conv2d(
118
- in_chans * 4,
119
- embed_dim,
120
- kernel_size=patch_size,
121
- stride=patch_stride,
122
- padding=padding,
123
- )
124
- else:
125
- self.proj = nn.Conv2d(
126
- in_chans,
127
- embed_dim,
128
- kernel_size=patch_size,
129
- stride=patch_stride,
130
- padding=padding,
131
- )
132
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
133
-
134
- if (self.enable_fusion) and (
135
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
136
- ):
137
- self.mel_conv2d = nn.Conv2d(
138
- in_chans,
139
- embed_dim,
140
- kernel_size=(patch_size[0], patch_size[1] * 3),
141
- stride=(patch_stride[0], patch_stride[1] * 3),
142
- padding=padding,
143
- )
144
- if self.fusion_type == "daf_2d":
145
- self.fusion_model = DAF()
146
- elif self.fusion_type == "aff_2d":
147
- self.fusion_model = AFF(channels=embed_dim, type="2D")
148
- elif self.fusion_type == "iaff_2d":
149
- self.fusion_model = iAFF(channels=embed_dim, type="2D")
150
-
151
- def forward(self, x, longer_idx=None):
152
- if (self.enable_fusion) and (
153
- self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
154
- ):
155
- global_x = x[:, 0:1, :, :]
156
-
157
- # global processing
158
- B, C, H, W = global_x.shape
159
- assert (
160
- H == self.img_size[0] and W == self.img_size[1]
161
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
162
- global_x = self.proj(global_x)
163
- TW = global_x.size(-1)
164
- if len(longer_idx) > 0:
165
- # local processing
166
- local_x = x[longer_idx, 1:, :, :].contiguous()
167
- B, C, H, W = local_x.shape
168
- local_x = local_x.view(B * C, 1, H, W)
169
- local_x = self.mel_conv2d(local_x)
170
- local_x = local_x.view(
171
- B, C, local_x.size(1), local_x.size(2), local_x.size(3)
172
- )
173
- local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
174
- TB, TC, TH, _ = local_x.size()
175
- if local_x.size(-1) < TW:
176
- local_x = torch.cat(
177
- [
178
- local_x,
179
- torch.zeros(
180
- (TB, TC, TH, TW - local_x.size(-1)),
181
- device=global_x.device,
182
- ),
183
- ],
184
- dim=-1,
185
- )
186
- else:
187
- local_x = local_x[:, :, :, :TW]
188
-
189
- global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
190
- x = global_x
191
- else:
192
- B, C, H, W = x.shape
193
- assert (
194
- H == self.img_size[0] and W == self.img_size[1]
195
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
196
- x = self.proj(x)
197
-
198
- if self.flatten:
199
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
200
- x = self.norm(x)
201
- return x
202
-
203
-
204
- class Mlp(nn.Module):
205
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
206
-
207
- def __init__(
208
- self,
209
- in_features,
210
- hidden_features=None,
211
- out_features=None,
212
- act_layer=nn.GELU,
213
- drop=0.0,
214
- ):
215
- super().__init__()
216
- out_features = out_features or in_features
217
- hidden_features = hidden_features or in_features
218
- self.fc1 = nn.Linear(in_features, hidden_features)
219
- self.act = act_layer()
220
- self.fc2 = nn.Linear(hidden_features, out_features)
221
- self.drop = nn.Dropout(drop)
222
-
223
- def forward(self, x):
224
- x = self.fc1(x)
225
- x = self.act(x)
226
- x = self.drop(x)
227
- x = self.fc2(x)
228
- x = self.drop(x)
229
- return x
230
-
231
-
232
- def _no_grad_trunc_normal_(tensor, mean, std, a, b):
233
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
234
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
235
- def norm_cdf(x):
236
- # Computes standard normal cumulative distribution function
237
- return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
238
-
239
- if (mean < a - 2 * std) or (mean > b + 2 * std):
240
- warnings.warn(
241
- "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
242
- "The distribution of values may be incorrect.",
243
- stacklevel=2,
244
- )
245
-
246
- with torch.no_grad():
247
- # Values are generated by using a truncated uniform distribution and
248
- # then using the inverse CDF for the normal distribution.
249
- # Get upper and lower cdf values
250
- l = norm_cdf((a - mean) / std)
251
- u = norm_cdf((b - mean) / std)
252
-
253
- # Uniformly fill tensor with values from [l, u], then translate to
254
- # [2l-1, 2u-1].
255
- tensor.uniform_(2 * l - 1, 2 * u - 1)
256
-
257
- # Use inverse cdf transform for normal distribution to get truncated
258
- # standard normal
259
- tensor.erfinv_()
260
-
261
- # Transform to proper mean, std
262
- tensor.mul_(std * math.sqrt(2.0))
263
- tensor.add_(mean)
264
-
265
- # Clamp to ensure it's in the proper range
266
- tensor.clamp_(min=a, max=b)
267
- return tensor
268
-
269
-
270
- def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
271
- # type: (Tensor, float, float, float, float) -> Tensor
272
- r"""Fills the input Tensor with values drawn from a truncated
273
- normal distribution. The values are effectively drawn from the
274
- normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
275
- with values outside :math:`[a, b]` redrawn until they are within
276
- the bounds. The method used for generating the random values works
277
- best when :math:`a \leq \text{mean} \leq b`.
278
- Args:
279
- tensor: an n-dimensional `torch.Tensor`
280
- mean: the mean of the normal distribution
281
- std: the standard deviation of the normal distribution
282
- a: the minimum cutoff value
283
- b: the maximum cutoff value
284
- Examples:
285
- >>> w = torch.empty(3, 5)
286
- >>> nn.init.trunc_normal_(w)
287
- """
288
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
289
-
290
-
291
- def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
292
- fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
293
- if mode == "fan_in":
294
- denom = fan_in
295
- elif mode == "fan_out":
296
- denom = fan_out
297
- elif mode == "fan_avg":
298
- denom = (fan_in + fan_out) / 2
299
-
300
- variance = scale / denom
301
-
302
- if distribution == "truncated_normal":
303
- # constant is stddev of standard normal truncated to (-2, 2)
304
- trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
305
- elif distribution == "normal":
306
- tensor.normal_(std=math.sqrt(variance))
307
- elif distribution == "uniform":
308
- bound = math.sqrt(3 * variance)
309
- tensor.uniform_(-bound, bound)
310
- else:
311
- raise ValueError(f"invalid distribution {distribution}")
312
-
313
-
314
- def lecun_normal_(tensor):
315
- variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
316
-
317
-
318
- def window_partition(x, window_size):
319
- """
320
- Args:
321
- x: (B, H, W, C)
322
- window_size (int): window size
323
- Returns:
324
- windows: (num_windows*B, window_size, window_size, C)
325
- """
326
- B, H, W, C = x.shape
327
- x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
328
- windows = (
329
- x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
330
- )
331
- return windows
332
-
333
-
334
- def window_reverse(windows, window_size, H, W):
335
- """
336
- Args:
337
- windows: (num_windows*B, window_size, window_size, C)
338
- window_size (int): Window size
339
- H (int): Height of image
340
- W (int): Width of image
341
- Returns:
342
- x: (B, H, W, C)
343
- """
344
- B = int(windows.shape[0] / (H * W / window_size / window_size))
345
- x = windows.view(
346
- B, H // window_size, W // window_size, window_size, window_size, -1
347
- )
348
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
349
- return x
350
-
351
-
352
- class WindowAttention(nn.Module):
353
- r"""Window based multi-head self attention (W-MSA) module with relative position bias.
354
- It supports both of shifted and non-shifted window.
355
- Args:
356
- dim (int): Number of input channels.
357
- window_size (tuple[int]): The height and width of the window.
358
- num_heads (int): Number of attention heads.
359
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
360
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
361
- attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
362
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
363
- """
364
-
365
- def __init__(
366
- self,
367
- dim,
368
- window_size,
369
- num_heads,
370
- qkv_bias=True,
371
- qk_scale=None,
372
- attn_drop=0.0,
373
- proj_drop=0.0,
374
- ):
375
-
376
- super().__init__()
377
- self.dim = dim
378
- self.window_size = window_size # Wh, Ww
379
- self.num_heads = num_heads
380
- head_dim = dim // num_heads
381
- self.scale = qk_scale or head_dim**-0.5
382
-
383
- # define a parameter table of relative position bias
384
- self.relative_position_bias_table = nn.Parameter(
385
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
386
- ) # 2*Wh-1 * 2*Ww-1, nH
387
-
388
- # get pair-wise relative position index for each token inside the window
389
- coords_h = torch.arange(self.window_size[0])
390
- coords_w = torch.arange(self.window_size[1])
391
- coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
392
- coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
393
- relative_coords = (
394
- coords_flatten[:, :, None] - coords_flatten[:, None, :]
395
- ) # 2, Wh*Ww, Wh*Ww
396
- relative_coords = relative_coords.permute(
397
- 1, 2, 0
398
- ).contiguous() # Wh*Ww, Wh*Ww, 2
399
- relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
400
- relative_coords[:, :, 1] += self.window_size[1] - 1
401
- relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
402
- relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
403
- self.register_buffer("relative_position_index", relative_position_index)
404
-
405
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
406
- self.attn_drop = nn.Dropout(attn_drop)
407
- self.proj = nn.Linear(dim, dim)
408
- self.proj_drop = nn.Dropout(proj_drop)
409
-
410
- trunc_normal_(self.relative_position_bias_table, std=0.02)
411
- self.softmax = nn.Softmax(dim=-1)
412
-
413
- def forward(self, x, mask=None):
414
- """
415
- Args:
416
- x: input features with shape of (num_windows*B, N, C)
417
- mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
418
- """
419
- B_, N, C = x.shape
420
- qkv = (
421
- self.qkv(x)
422
- .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
423
- .permute(2, 0, 3, 1, 4)
424
- )
425
- q, k, v = (
426
- qkv[0],
427
- qkv[1],
428
- qkv[2],
429
- ) # make torchscript happy (cannot use tensor as tuple)
430
-
431
- q = q * self.scale
432
- attn = q @ k.transpose(-2, -1)
433
-
434
- relative_position_bias = self.relative_position_bias_table[
435
- self.relative_position_index.view(-1)
436
- ].view(
437
- self.window_size[0] * self.window_size[1],
438
- self.window_size[0] * self.window_size[1],
439
- -1,
440
- ) # Wh*Ww,Wh*Ww,nH
441
- relative_position_bias = relative_position_bias.permute(
442
- 2, 0, 1
443
- ).contiguous() # nH, Wh*Ww, Wh*Ww
444
- attn = attn + relative_position_bias.unsqueeze(0)
445
-
446
- if mask is not None:
447
- nW = mask.shape[0]
448
- attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
449
- 1
450
- ).unsqueeze(0)
451
- attn = attn.view(-1, self.num_heads, N, N)
452
- attn = self.softmax(attn)
453
- else:
454
- attn = self.softmax(attn)
455
-
456
- attn = self.attn_drop(attn)
457
-
458
- x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
459
- x = self.proj(x)
460
- x = self.proj_drop(x)
461
- return x, attn
462
-
463
- def extra_repr(self):
464
- return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
465
-
466
-
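A minimal standalone sketch of the relative-position-index computation from WindowAttention.__init__ above, run on a tiny 2x2 window so the indexing into the (2*Wh-1)*(2*Ww-1) bias table is easy to inspect; the window size here is illustrative only.

# Sketch: relative position index for a 2x2 attention window,
# mirroring the buffer computed in WindowAttention.__init__ above.
import torch

window_size = (2, 2)  # (Wh, Ww), illustrative
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))   # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                    # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1   # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)            # Wh*Ww, Wh*Ww
print(relative_position_index)
# Each entry indexes into a (2*Wh-1)*(2*Ww-1) = 9-row bias table,
# so every head learns one bias per relative offset, shared across windows.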
467
- # The model is based on the Swin Transformer block, so Swin Transformer pretrained weights can be reused
468
- class SwinTransformerBlock(nn.Module):
469
- r"""Swin Transformer Block.
470
- Args:
471
- dim (int): Number of input channels.
472
- input_resolution (tuple[int]): Input resolution.
473
- num_heads (int): Number of attention heads.
474
- window_size (int): Window size.
475
- shift_size (int): Shift size for SW-MSA.
476
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
477
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
478
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
479
- drop (float, optional): Dropout rate. Default: 0.0
480
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
481
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
482
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
483
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
484
- """
485
-
486
- def __init__(
487
- self,
488
- dim,
489
- input_resolution,
490
- num_heads,
491
- window_size=7,
492
- shift_size=0,
493
- mlp_ratio=4.0,
494
- qkv_bias=True,
495
- qk_scale=None,
496
- drop=0.0,
497
- attn_drop=0.0,
498
- drop_path=0.0,
499
- act_layer=nn.GELU,
500
- norm_layer=nn.LayerNorm,
501
- norm_before_mlp="ln",
502
- ):
503
- super().__init__()
504
- self.dim = dim
505
- self.input_resolution = input_resolution
506
- self.num_heads = num_heads
507
- self.window_size = window_size
508
- self.shift_size = shift_size
509
- self.mlp_ratio = mlp_ratio
510
- self.norm_before_mlp = norm_before_mlp
511
- if min(self.input_resolution) <= self.window_size:
512
- # if window size is larger than input resolution, we don't partition windows
513
- self.shift_size = 0
514
- self.window_size = min(self.input_resolution)
515
- assert (
516
- 0 <= self.shift_size < self.window_size
517
- ), "shift_size must in 0-window_size"
518
-
519
- self.norm1 = norm_layer(dim)
520
- self.attn = WindowAttention(
521
- dim,
522
- window_size=to_2tuple(self.window_size),
523
- num_heads=num_heads,
524
- qkv_bias=qkv_bias,
525
- qk_scale=qk_scale,
526
- attn_drop=attn_drop,
527
- proj_drop=drop,
528
- )
529
-
530
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
531
- if self.norm_before_mlp == "ln":
532
- self.norm2 = nn.LayerNorm(dim)
533
- elif self.norm_before_mlp == "bn":
534
- self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
535
- 1, 2
536
- )
537
- else:
538
- raise NotImplementedError
539
- mlp_hidden_dim = int(dim * mlp_ratio)
540
- self.mlp = Mlp(
541
- in_features=dim,
542
- hidden_features=mlp_hidden_dim,
543
- act_layer=act_layer,
544
- drop=drop,
545
- )
546
-
547
- if self.shift_size > 0:
548
- # calculate attention mask for SW-MSA
549
- H, W = self.input_resolution
550
- img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
551
- h_slices = (
552
- slice(0, -self.window_size),
553
- slice(-self.window_size, -self.shift_size),
554
- slice(-self.shift_size, None),
555
- )
556
- w_slices = (
557
- slice(0, -self.window_size),
558
- slice(-self.window_size, -self.shift_size),
559
- slice(-self.shift_size, None),
560
- )
561
- cnt = 0
562
- for h in h_slices:
563
- for w in w_slices:
564
- img_mask[:, h, w, :] = cnt
565
- cnt += 1
566
-
567
- mask_windows = window_partition(
568
- img_mask, self.window_size
569
- ) # nW, window_size, window_size, 1
570
- mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
571
- attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
572
- attn_mask = attn_mask.masked_fill(
573
- attn_mask != 0, float(-100.0)
574
- ).masked_fill(attn_mask == 0, float(0.0))
575
- else:
576
- attn_mask = None
577
-
578
- self.register_buffer("attn_mask", attn_mask)
579
-
580
- def forward(self, x):
581
- # pdb.set_trace()
582
- H, W = self.input_resolution
583
- # print("H: ", H)
584
- # print("W: ", W)
585
- # pdb.set_trace()
586
- B, L, C = x.shape
587
- # assert L == H * W, "input feature has wrong size"
588
-
589
- shortcut = x
590
- x = self.norm1(x)
591
- x = x.view(B, H, W, C)
592
-
593
- # cyclic shift
594
- if self.shift_size > 0:
595
- shifted_x = torch.roll(
596
- x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
597
- )
598
- else:
599
- shifted_x = x
600
-
601
- # partition windows
602
- x_windows = window_partition(
603
- shifted_x, self.window_size
604
- ) # nW*B, window_size, window_size, C
605
- x_windows = x_windows.view(
606
- -1, self.window_size * self.window_size, C
607
- ) # nW*B, window_size*window_size, C
608
-
609
- # W-MSA/SW-MSA
610
- attn_windows, attn = self.attn(
611
- x_windows, mask=self.attn_mask
612
- ) # nW*B, window_size*window_size, C
613
-
614
- # merge windows
615
- attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
616
- shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
617
-
618
- # reverse cyclic shift
619
- if self.shift_size > 0:
620
- x = torch.roll(
621
- shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
622
- )
623
- else:
624
- x = shifted_x
625
- x = x.view(B, H * W, C)
626
-
627
- # FFN
628
- x = shortcut + self.drop_path(x)
629
- x = x + self.drop_path(self.mlp(self.norm2(x)))
630
-
631
- return x, attn
632
-
633
- def extra_repr(self):
634
- return (
635
- f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
636
- f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
637
- )
638
-
639
-
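A small sketch of how the SW-MSA attention mask built in SwinTransformerBlock.__init__ above works, on a toy 4x4 resolution with window 2 and shift 1; the window_partition helper below is a simplified stand-in for the one used elsewhere in this file, and all sizes are illustrative.

# Sketch: build the SW-MSA mask for a toy 4x4 feature map, window 2, shift 1.
import torch

def window_partition(x, ws):
    # simplified stand-in for the window_partition used by the block above
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, C)

H = W = 4
window_size, shift_size = 2, 1
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt   # label the regions that get rolled together
        cnt += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # (nW, 4, 4): -100 blocks attention across rolled-in regions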
640
- class PatchMerging(nn.Module):
641
- r"""Patch Merging Layer.
642
- Args:
643
- input_resolution (tuple[int]): Resolution of input feature.
644
- dim (int): Number of input channels.
645
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
646
- """
647
-
648
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
649
- super().__init__()
650
- self.input_resolution = input_resolution
651
- self.dim = dim
652
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
653
- self.norm = norm_layer(4 * dim)
654
-
655
- def forward(self, x):
656
- """
657
- x: B, H*W, C
658
- """
659
- H, W = self.input_resolution
660
- B, L, C = x.shape
661
- assert L == H * W, "input feature has wrong size"
662
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
663
-
664
- x = x.view(B, H, W, C)
665
-
666
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
667
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
668
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
669
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
670
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
671
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
672
-
673
- x = self.norm(x)
674
- x = self.reduction(x)
675
-
676
- return x
677
-
678
- def extra_repr(self):
679
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
680
-
681
-
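A quick sketch of the shape arithmetic behind PatchMerging: four strided 2x2 views are concatenated on the channel axis and linearly reduced, so (H, W, C) becomes (H/2, W/2, 2C); the sizes below are arbitrary.

# Sketch: PatchMerging reduces (H, W, C) -> (H/2, W/2, 2C) via 2x2 concat + linear.
import torch
from torch import nn

B, H, W, C = 1, 8, 8, 96
x = torch.randn(B, H * W, C).view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :]
x1 = x[:, 1::2, 0::2, :]
x2 = x[:, 0::2, 1::2, :]
x3 = x[:, 1::2, 1::2, :]
merged = torch.cat([x0, x1, x2, x3], -1).view(B, -1, 4 * C)          # B, H/2*W/2, 4C
reduced = nn.Linear(4 * C, 2 * C, bias=False)(nn.LayerNorm(4 * C)(merged))
print(merged.shape, reduced.shape)   # (1, 16, 384) (1, 16, 192)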
682
- class BasicLayer(nn.Module):
683
- """A basic Swin Transformer layer for one stage.
684
- Args:
685
- dim (int): Number of input channels.
686
- input_resolution (tuple[int]): Input resolution.
687
- depth (int): Number of blocks.
688
- num_heads (int): Number of attention heads.
689
- window_size (int): Local window size.
690
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
691
- qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
692
- qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
693
- drop (float, optional): Dropout rate. Default: 0.0
694
- attn_drop (float, optional): Attention dropout rate. Default: 0.0
695
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
696
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
697
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
698
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
699
- """
700
-
701
- def __init__(
702
- self,
703
- dim,
704
- input_resolution,
705
- depth,
706
- num_heads,
707
- window_size,
708
- mlp_ratio=4.0,
709
- qkv_bias=True,
710
- qk_scale=None,
711
- drop=0.0,
712
- attn_drop=0.0,
713
- drop_path=0.0,
714
- norm_layer=nn.LayerNorm,
715
- downsample=None,
716
- use_checkpoint=False,
717
- norm_before_mlp="ln",
718
- ):
719
-
720
- super().__init__()
721
- self.dim = dim
722
- self.input_resolution = input_resolution
723
- self.depth = depth
724
- self.use_checkpoint = use_checkpoint
725
-
726
- # build blocks
727
- self.blocks = nn.ModuleList(
728
- [
729
- SwinTransformerBlock(
730
- dim=dim,
731
- input_resolution=input_resolution,
732
- num_heads=num_heads,
733
- window_size=window_size,
734
- shift_size=0 if (i % 2 == 0) else window_size // 2,
735
- mlp_ratio=mlp_ratio,
736
- qkv_bias=qkv_bias,
737
- qk_scale=qk_scale,
738
- drop=drop,
739
- attn_drop=attn_drop,
740
- drop_path=drop_path[i]
741
- if isinstance(drop_path, list)
742
- else drop_path,
743
- norm_layer=norm_layer,
744
- norm_before_mlp=norm_before_mlp,
745
- )
746
- for i in range(depth)
747
- ]
748
- )
749
-
750
- # patch merging layer
751
- if downsample is not None:
752
- self.downsample = downsample(
753
- input_resolution, dim=dim, norm_layer=norm_layer
754
- )
755
- else:
756
- self.downsample = None
757
-
758
- def forward(self, x):
759
- attns = []
760
- for blk in self.blocks:
761
- if self.use_checkpoint:
762
- x = checkpoint.checkpoint(blk, x)
763
- else:
764
- x, attn = blk(x)
765
- if not self.training:
766
- attns.append(attn.unsqueeze(0))
767
- if self.downsample is not None:
768
- x = self.downsample(x)
769
- if not self.training:
770
- attn = torch.cat(attns, dim=0)
771
- attn = torch.mean(attn, dim=0)
772
- return x, attn
773
-
774
- def extra_repr(self):
775
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
776
-
777
-
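The drop_path value each SwinTransformerBlock receives comes from a linearly increasing stochastic-depth schedule that the HTSAT constructor below slices per stage; a small sketch of that slicing, assuming the default depths=[2, 2, 6, 2] and drop_path_rate=0.1.

# Sketch: the linear stochastic-depth schedule sliced per stage (see dpr below).
import torch

depths = [2, 2, 6, 2]
drop_path_rate = 0.1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
for i_layer in range(len(depths)):
    stage = dpr[sum(depths[:i_layer]): sum(depths[: i_layer + 1])]
    print(i_layer, [round(v, 3) for v in stage])
# Later blocks get a higher drop probability than earlier ones.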
778
- # The Core of HTSAT
779
- class HTSAT_Swin_Transformer(nn.Module):
780
- r"""HTSAT based on the Swin Transformer
781
- Args:
782
- spec_size (int | tuple(int)): Input Spectrogram size. Default 256
783
- patch_size (int | tuple(int)): Patch size. Default: 4
784
- patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: 4
785
- in_chans (int): Number of input image channels. Default: 1 (mono)
786
- num_classes (int): Number of classes for classification head. Default: 527
787
- embed_dim (int): Patch embedding dimension. Default: 96
788
- depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
789
- num_heads (tuple(int)): Number of attention heads in different layers.
790
- window_size (int): Window size. Default: 8
791
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
792
- qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
793
- qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
794
- drop_rate (float): Dropout rate. Default: 0
795
- attn_drop_rate (float): Attention dropout rate. Default: 0
796
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
797
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
798
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
799
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
800
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
801
- config (module): The configuration Module from config.py
802
- """
803
-
804
- def __init__(
805
- self,
806
- spec_size=256,
807
- patch_size=4,
808
- patch_stride=(4, 4),
809
- in_chans=1,
810
- num_classes=527,
811
- embed_dim=96,
812
- depths=[2, 2, 6, 2],
813
- num_heads=[4, 8, 16, 32],
814
- window_size=8,
815
- mlp_ratio=4.0,
816
- qkv_bias=True,
817
- qk_scale=None,
818
- drop_rate=0.0,
819
- attn_drop_rate=0.0,
820
- drop_path_rate=0.1,
821
- norm_layer=nn.LayerNorm,
822
- ape=False,
823
- patch_norm=True,
824
- use_checkpoint=False,
825
- norm_before_mlp="ln",
826
- config=None,
827
- enable_fusion=False,
828
- fusion_type="None",
829
- **kwargs,
830
- ):
831
- super(HTSAT_Swin_Transformer, self).__init__()
832
-
833
- self.config = config
834
- self.spec_size = spec_size
835
- self.patch_stride = patch_stride
836
- self.patch_size = patch_size
837
- self.window_size = window_size
838
- self.embed_dim = embed_dim
839
- self.depths = depths
840
- self.ape = ape
841
- self.in_chans = in_chans
842
- self.num_classes = num_classes
843
- self.num_heads = num_heads
844
- self.num_layers = len(self.depths)
845
- self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))
846
-
847
- self.drop_rate = drop_rate
848
- self.attn_drop_rate = attn_drop_rate
849
- self.drop_path_rate = drop_path_rate
850
-
851
- self.qkv_bias = qkv_bias
852
- self.qk_scale = None
853
-
854
- self.patch_norm = patch_norm
855
- self.norm_layer = norm_layer if self.patch_norm else None
856
- self.norm_before_mlp = norm_before_mlp
857
- self.mlp_ratio = mlp_ratio
858
-
859
- self.use_checkpoint = use_checkpoint
860
-
861
- self.enable_fusion = enable_fusion
862
- self.fusion_type = fusion_type
863
-
864
- # process mel-spec; used only once
865
- self.freq_ratio = self.spec_size // self.config.mel_bins
866
- window = "hann"
867
- center = True
868
- pad_mode = "reflect"
869
- ref = 1.0
870
- amin = 1e-10
871
- top_db = None
872
- self.interpolate_ratio = 32 # Downsampled ratio
873
- # Spectrogram extractor
874
- self.spectrogram_extractor = Spectrogram(
875
- n_fft=config.window_size,
876
- hop_length=config.hop_size,
877
- win_length=config.window_size,
878
- window=window,
879
- center=center,
880
- pad_mode=pad_mode,
881
- freeze_parameters=True,
882
- )
883
- # Logmel feature extractor
884
- self.logmel_extractor = LogmelFilterBank(
885
- sr=config.sample_rate,
886
- n_fft=config.window_size,
887
- n_mels=config.mel_bins,
888
- fmin=config.fmin,
889
- fmax=config.fmax,
890
- ref=ref,
891
- amin=amin,
892
- top_db=top_db,
893
- freeze_parameters=True,
894
- )
895
- # Spec augmenter
896
- self.spec_augmenter = SpecAugmentation(
897
- time_drop_width=64,
898
- time_stripes_num=2,
899
- freq_drop_width=8,
900
- freq_stripes_num=2,
901
- ) # 2 2
902
- self.bn0 = nn.BatchNorm2d(self.config.mel_bins)
903
-
904
- # split spectrogram into non-overlapping patches
905
- self.patch_embed = PatchEmbed(
906
- img_size=self.spec_size,
907
- patch_size=self.patch_size,
908
- in_chans=self.in_chans,
909
- embed_dim=self.embed_dim,
910
- norm_layer=self.norm_layer,
911
- patch_stride=patch_stride,
912
- enable_fusion=self.enable_fusion,
913
- fusion_type=self.fusion_type,
914
- )
915
-
916
- num_patches = self.patch_embed.num_patches
917
- patches_resolution = self.patch_embed.grid_size
918
- self.patches_resolution = patches_resolution
919
-
920
- # absolute position embedding
921
- if self.ape:
922
- self.absolute_pos_embed = nn.Parameter(
923
- torch.zeros(1, num_patches, self.embed_dim)
924
- )
925
- trunc_normal_(self.absolute_pos_embed, std=0.02)
926
-
927
- self.pos_drop = nn.Dropout(p=self.drop_rate)
928
-
929
- # stochastic depth
930
- dpr = [
931
- x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
932
- ] # stochastic depth decay rule
933
-
934
- # build layers
935
- self.layers = nn.ModuleList()
936
- for i_layer in range(self.num_layers):
937
- layer = BasicLayer(
938
- dim=int(self.embed_dim * 2**i_layer),
939
- input_resolution=(
940
- patches_resolution[0] // (2**i_layer),
941
- patches_resolution[1] // (2**i_layer),
942
- ),
943
- depth=self.depths[i_layer],
944
- num_heads=self.num_heads[i_layer],
945
- window_size=self.window_size,
946
- mlp_ratio=self.mlp_ratio,
947
- qkv_bias=self.qkv_bias,
948
- qk_scale=self.qk_scale,
949
- drop=self.drop_rate,
950
- attn_drop=self.attn_drop_rate,
951
- drop_path=dpr[
952
- sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
953
- ],
954
- norm_layer=self.norm_layer,
955
- downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
956
- use_checkpoint=use_checkpoint,
957
- norm_before_mlp=self.norm_before_mlp,
958
- )
959
- self.layers.append(layer)
960
-
961
- self.norm = self.norm_layer(self.num_features)
962
- self.avgpool = nn.AdaptiveAvgPool1d(1)
963
- self.maxpool = nn.AdaptiveMaxPool1d(1)
964
-
965
- SF = (
966
- self.spec_size
967
- // (2 ** (len(self.depths) - 1))
968
- // self.patch_stride[0]
969
- // self.freq_ratio
970
- )
971
- self.tscam_conv = nn.Conv2d(
972
- in_channels=self.num_features,
973
- out_channels=self.num_classes,
974
- kernel_size=(SF, 3),
975
- padding=(0, 1),
976
- )
977
- self.head = nn.Linear(num_classes, num_classes)
978
-
979
- if (self.enable_fusion) and (
980
- self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
981
- ):
982
- self.mel_conv1d = nn.Sequential(
983
- nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
984
- nn.BatchNorm1d(64),
985
- )
986
- if self.fusion_type == "daf_1d":
987
- self.fusion_model = DAF()
988
- elif self.fusion_type == "aff_1d":
989
- self.fusion_model = AFF(channels=64, type="1D")
990
- elif self.fusion_type == "iaff_1d":
991
- self.fusion_model = iAFF(channels=64, type="1D")
992
-
993
- self.apply(self._init_weights)
994
-
995
- def _init_weights(self, m):
996
- if isinstance(m, nn.Linear):
997
- trunc_normal_(m.weight, std=0.02)
998
- if isinstance(m, nn.Linear) and m.bias is not None:
999
- nn.init.constant_(m.bias, 0)
1000
- elif isinstance(m, nn.LayerNorm):
1001
- nn.init.constant_(m.bias, 0)
1002
- nn.init.constant_(m.weight, 1.0)
1003
-
1004
- @torch.jit.ignore
1005
- def no_weight_decay(self):
1006
- return {"absolute_pos_embed"}
1007
-
1008
- @torch.jit.ignore
1009
- def no_weight_decay_keywords(self):
1010
- return {"relative_position_bias_table"}
1011
-
1012
- def forward_features(self, x, longer_idx=None):
1013
- # A deprecated optimization for using a hierarchical output from different blocks
1014
-
1015
- frames_num = x.shape[2]
1016
- x = self.patch_embed(x, longer_idx=longer_idx)
1017
- if self.ape:
1018
- x = x + self.absolute_pos_embed
1019
- x = self.pos_drop(x)
1020
- for i, layer in enumerate(self.layers):
1021
- x, attn = layer(x)
1022
- # for x
1023
- x = self.norm(x)
1024
- B, N, C = x.shape
1025
- SF = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[0]
1026
- ST = frames_num // (2 ** (len(self.depths) - 1)) // self.patch_stride[1]
1027
- x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
1028
- B, C, F, T = x.shape
1029
- # group 2D CNN
1030
- c_freq_bin = F // self.freq_ratio
1031
- x = x.reshape(B, C, F // c_freq_bin, c_freq_bin, T)
1032
- x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
1033
- # get latent_output
1034
- fine_grained_latent_output = torch.mean(x, dim=2)
1035
- fine_grained_latent_output = interpolate(
1036
- fine_grained_latent_output.permute(0, 2, 1).contiguous(),
1037
- 8 * self.patch_stride[1],
1038
- )
1039
-
1040
- latent_output = self.avgpool(torch.flatten(x, 2))
1041
- latent_output = torch.flatten(latent_output, 1)
1042
-
1043
- # display the attention map, if needed
1044
-
1045
- x = self.tscam_conv(x)
1046
- x = torch.flatten(x, 2) # B, C, T
1047
-
1048
- fpx = interpolate(
1049
- torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
1050
- )
1051
-
1052
- x = self.avgpool(x)
1053
- x = torch.flatten(x, 1)
1054
-
1055
- output_dict = {
1056
- "framewise_output": fpx, # already sigmoided
1057
- "clipwise_output": torch.sigmoid(x),
1058
- "fine_grained_embedding": fine_grained_latent_output,
1059
- "embedding": latent_output,
1060
- }
1061
-
1062
- return output_dict
1063
-
1064
- def crop_wav(self, x, crop_size, spe_pos=None):
1065
- time_steps = x.shape[2]
1066
- tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
1067
- for i in range(len(x)):
1068
- if spe_pos is None:
1069
- crop_pos = random.randint(0, time_steps - crop_size - 1)
1070
- else:
1071
- crop_pos = spe_pos
1072
- tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
1073
- return tx
1074
-
1075
- # Reshape the waveform to an image-like size, if you want to use the pretrained Swin Transformer model
1076
- def reshape_wav2img(self, x):
1077
- B, C, T, F = x.shape
1078
- target_T = int(self.spec_size * self.freq_ratio)
1079
- target_F = self.spec_size // self.freq_ratio
1080
- assert (
1081
- T <= target_T and F <= target_F
1082
- ), "the wav size should less than or equal to the swin input size"
1083
- # to avoid bicubic zero error
1084
- if T < target_T:
1085
- x = nn.functional.interpolate(
1086
- x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1087
- )
1088
- if F < target_F:
1089
- x = nn.functional.interpolate(
1090
- x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1091
- )
1092
- x = x.permute(0, 1, 3, 2).contiguous()
1093
- x = x.reshape(
1094
- x.shape[0],
1095
- x.shape[1],
1096
- x.shape[2],
1097
- self.freq_ratio,
1098
- x.shape[3] // self.freq_ratio,
1099
- )
1100
- # print(x.shape)
1101
- x = x.permute(0, 1, 3, 2, 4).contiguous()
1102
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
1103
- return x
1104
-
1105
- # Repeat the waveform to an image-like size, if you want to use the pretrained Swin Transformer model
1106
- def repeat_wat2img(self, x, cur_pos):
1107
- B, C, T, F = x.shape
1108
- target_T = int(self.spec_size * self.freq_ratio)
1109
- target_F = self.spec_size // self.freq_ratio
1110
- assert (
1111
- T <= target_T and F <= target_F
1112
- ), "the wav size should less than or equal to the swin input size"
1113
- # to avoid bicubic zero error
1114
- if T < target_T:
1115
- x = nn.functional.interpolate(
1116
- x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
1117
- )
1118
- if F < target_F:
1119
- x = nn.functional.interpolate(
1120
- x, (x.shape[2], target_F), mode="bicubic", align_corners=True
1121
- )
1122
- x = x.permute(0, 1, 3, 2).contiguous() # B C F T
1123
- x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
1124
- x = x.repeat(repeats=(1, 1, 4, 1))
1125
- return x
1126
-
1127
- def forward(
1128
- self, x: torch.Tensor, mixup_lambda=None, infer_mode=False, device=None
1129
- ): # out_feat_keys: List[str] = None):
1130
-
1131
- if self.enable_fusion and x["longer"].sum() == 0:
1132
- # if no audio is longer than 10s, then randomly select one audio to be longer
1133
- x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True
1134
-
1135
- if not self.enable_fusion:
1136
- x = x["waveform"].to(device=device, non_blocking=True)
1137
- x = self.spectrogram_extractor(x) # (batch_size, 1, time_steps, freq_bins)
1138
- x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
1139
- x = x.transpose(1, 3)
1140
- x = self.bn0(x)
1141
- x = x.transpose(1, 3)
1142
- if self.training:
1143
- x = self.spec_augmenter(x)
1144
-
1145
- if self.training and mixup_lambda is not None:
1146
- x = do_mixup(x, mixup_lambda)
1147
-
1148
- x = self.reshape_wav2img(x)
1149
- output_dict = self.forward_features(x)
1150
- else:
1151
- longer_list = x["longer"].to(device=device, non_blocking=True)
1152
- x = x["mel_fusion"].to(device=device, non_blocking=True)
1153
- x = x.transpose(1, 3)
1154
- x = self.bn0(x)
1155
- x = x.transpose(1, 3)
1156
- longer_list_idx = torch.where(longer_list)[0]
1157
- if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
1158
- new_x = x[:, 0:1, :, :].clone().contiguous()
1159
- if len(longer_list_idx) > 0:
1160
- # local processing
1161
- fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
1162
- FB, FC, FT, FF = fusion_x_local.size()
1163
- fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
1164
- fusion_x_local = torch.permute(
1165
- fusion_x_local, (0, 2, 1)
1166
- ).contiguous()
1167
- fusion_x_local = self.mel_conv1d(fusion_x_local)
1168
- fusion_x_local = fusion_x_local.view(
1169
- FB, FC, FF, fusion_x_local.size(-1)
1170
- )
1171
- fusion_x_local = (
1172
- torch.permute(fusion_x_local, (0, 2, 1, 3))
1173
- .contiguous()
1174
- .flatten(2)
1175
- )
1176
- if fusion_x_local.size(-1) < FT:
1177
- fusion_x_local = torch.cat(
1178
- [
1179
- fusion_x_local,
1180
- torch.zeros(
1181
- (FB, FF, FT - fusion_x_local.size(-1)),
1182
- device=device,
1183
- ),
1184
- ],
1185
- dim=-1,
1186
- )
1187
- else:
1188
- fusion_x_local = fusion_x_local[:, :, :FT]
1189
- # 1D fusion
1190
- new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
1191
- new_x[longer_list_idx] = self.fusion_model(
1192
- new_x[longer_list_idx], fusion_x_local
1193
- )
1194
- x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
1195
- else:
1196
- x = new_x
1197
-
1198
- elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
1199
- x = x # no change
1200
-
1201
- if self.training:
1202
- x = self.spec_augmenter(x)
1203
- if self.training and mixup_lambda is not None:
1204
- x = do_mixup(x, mixup_lambda)
1205
-
1206
- x = self.reshape_wav2img(x)
1207
- output_dict = self.forward_features(x, longer_idx=longer_list_idx)
1208
-
1209
- # if infer_mode:
1210
- # # in infer mode. we need to handle different length audio input
1211
- # frame_num = x.shape[2]
1212
- # target_T = int(self.spec_size * self.freq_ratio)
1213
- # repeat_ratio = math.floor(target_T / frame_num)
1214
- # x = x.repeat(repeats=(1,1,repeat_ratio,1))
1215
- # x = self.reshape_wav2img(x)
1216
- # output_dict = self.forward_features(x)
1217
- # else:
1218
- # if x.shape[2] > self.freq_ratio * self.spec_size:
1219
- # if self.training:
1220
- # x = self.crop_wav(x, crop_size=self.freq_ratio * self.spec_size)
1221
- # x = self.reshape_wav2img(x)
1222
- # output_dict = self.forward_features(x)
1223
- # else:
1224
- # # Change: Hard code here
1225
- # overlap_size = (x.shape[2] - 1) // 4
1226
- # output_dicts = []
1227
- # crop_size = (x.shape[2] - 1) // 2
1228
- # for cur_pos in range(0, x.shape[2] - crop_size - 1, overlap_size):
1229
- # tx = self.crop_wav(x, crop_size = crop_size, spe_pos = cur_pos)
1230
- # tx = self.reshape_wav2img(tx)
1231
- # output_dicts.append(self.forward_features(tx))
1232
- # clipwise_output = torch.zeros_like(output_dicts[0]["clipwise_output"]).float().to(x.device)
1233
- # framewise_output = torch.zeros_like(output_dicts[0]["framewise_output"]).float().to(x.device)
1234
- # for d in output_dicts:
1235
- # clipwise_output += d["clipwise_output"]
1236
- # framewise_output += d["framewise_output"]
1237
- # clipwise_output = clipwise_output / len(output_dicts)
1238
- # framewise_output = framewise_output / len(output_dicts)
1239
- # output_dict = {
1240
- # 'framewise_output': framewise_output,
1241
- # 'clipwise_output': clipwise_output
1242
- # }
1243
- # else: # this part is typically used, and most easy one
1244
- # x = self.reshape_wav2img(x)
1245
- # output_dict = self.forward_features(x)
1246
- # x = self.head(x)
1247
-
1248
- # We process the data in the dataloader, so here we only consider the case input_T < fixed_T
1249
-
1250
- return output_dict
1251
-
1252
-
1253
- def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
1254
- try:
1255
-
1256
- assert audio_cfg.model_name in [
1257
- "tiny",
1258
- "base",
1259
- "large",
1260
- ], "model name for HTS-AT is wrong!"
1261
- if audio_cfg.model_name == "tiny":
1262
- model = HTSAT_Swin_Transformer(
1263
- spec_size=256,
1264
- patch_size=4,
1265
- patch_stride=(4, 4),
1266
- num_classes=audio_cfg.class_num,
1267
- embed_dim=96,
1268
- depths=[2, 2, 6, 2],
1269
- num_heads=[4, 8, 16, 32],
1270
- window_size=8,
1271
- config=audio_cfg,
1272
- enable_fusion=enable_fusion,
1273
- fusion_type=fusion_type,
1274
- )
1275
- elif audio_cfg.model_name == "base":
1276
- model = HTSAT_Swin_Transformer(
1277
- spec_size=256,
1278
- patch_size=4,
1279
- patch_stride=(4, 4),
1280
- num_classes=audio_cfg.class_num,
1281
- embed_dim=128,
1282
- depths=[2, 2, 12, 2],
1283
- num_heads=[4, 8, 16, 32],
1284
- window_size=8,
1285
- config=audio_cfg,
1286
- enable_fusion=enable_fusion,
1287
- fusion_type=fusion_type,
1288
- )
1289
- elif audio_cfg.model_name == "large":
1290
- model = HTSAT_Swin_Transformer(
1291
- spec_size=256,
1292
- patch_size=4,
1293
- patch_stride=(4, 4),
1294
- num_classes=audio_cfg.class_num,
1295
- embed_dim=256,
1296
- depths=[2, 2, 12, 2],
1297
- num_heads=[4, 8, 16, 32],
1298
- window_size=8,
1299
- config=audio_cfg,
1300
- enable_fusion=enable_fusion,
1301
- fusion_type=fusion_type,
1302
- )
1303
-
1304
- return model
1305
- except:
1306
- raise RuntimeError(
1307
- f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
1308
- )
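A hedged usage sketch of create_htsat_model, assuming this module and its dependencies (torchlibrosa, the fusion helpers, etc.) are importable; the config fields mirror the HTSAT-tiny audio config shipped with CLAP, but the exact values here are illustrative rather than authoritative.

# Sketch: building the "tiny" HTS-AT audio branch with a hypothetical config.
from types import SimpleNamespace

audio_cfg = SimpleNamespace(     # illustrative values, not the canonical config
    model_name="tiny",
    sample_rate=48000,
    window_size=1024,            # n_fft / win_length of the STFT
    hop_size=480,
    mel_bins=64,
    fmin=50,
    fmax=14000,
    class_num=527,               # AudioSet classes
)
model = create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None")
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")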
audioldm/clap/open_clip/linear_probe.py DELETED
@@ -1,66 +0,0 @@
1
- import numpy as np
2
- import torch.nn.functional as F
3
- from torch import nn
4
- from .model import MLPLayers
5
-
6
-
7
- class LinearProbe(nn.Module):
8
- def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
9
- """
10
- Args:
11
- model: nn.Module
12
- mlp: bool, if True, then use the MLP layer as the linear probe module
13
- freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe
14
- in_ch: int, the output channel from CLAP model
15
- out_ch: int, the output channel from linear probe (class_num)
16
- act: torch.nn.functional, the activation function before the loss function
17
- """
18
- super().__init__()
19
- in_ch = 512
20
- self.clap_model = model
21
- self.clap_model.text_branch = None # to save memory
22
- self.freeze = freeze
23
- if mlp:
24
- self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
25
- else:
26
- self.lp_layer = nn.Linear(in_ch, out_ch)
27
-
28
- if self.freeze:
29
- for param in self.clap_model.parameters():
30
- param.requires_grad = False
31
-
32
- if act == "None":
33
- self.act = None
34
- elif act == "relu":
35
- self.act = nn.ReLU()
36
- elif act == "elu":
37
- self.act = nn.ELU()
38
- elif act == "prelu":
39
- self.act = nn.PReLU(num_parameters=in_ch)
40
- elif act == "softmax":
41
- self.act = nn.Softmax(dim=-1)
42
- elif act == "sigmoid":
43
- self.act = nn.Sigmoid()
44
-
45
- def forward(self, x, mix_lambda=None, device=None):
46
- """
47
- Args:
48
- x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
49
- mix_lambda: torch.tensor [batch], the mixup lambda
50
- Returns:
51
- class_prob: torch.tensor [batch, class_num]
52
-
53
- """
54
- # put the frozen CLAP model in eval mode so batchnorm statistics are not updated
55
- if self.freeze:
56
- self.clap_model.eval()
57
-
58
- x = self.clap_model.audio_projection(
59
- self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
60
- "embedding"
61
- ]
62
- )
63
- out = self.lp_layer(x)
64
- if self.act is not None:
65
- out = self.act(out)
66
audioldm/clap/open_clip/loss.py DELETED
@@ -1,398 +0,0 @@
1
- from multiprocessing.sharedctypes import Value
2
- import torch
3
- import torch.distributed.nn
4
- from torch import distributed as dist, nn as nn
5
- from torch.nn import functional as F
6
- import numpy as np
7
- from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
8
-
9
- try:
10
- import horovod.torch as hvd
11
- except ImportError:
12
- hvd = None
13
-
14
-
15
- def gather_features(
16
- audio_features,
17
- text_features,
18
- audio_features_mlp=None,
19
- text_features_mlp=None,
20
- local_loss=False,
21
- gather_with_grad=False,
22
- rank=0,
23
- world_size=1,
24
- use_horovod=False,
25
- mlp_loss=False,
26
- ):
27
- if use_horovod:
28
- assert hvd is not None, "Please install horovod"
29
- if gather_with_grad:
30
- all_audio_features = hvd.allgather(audio_features)
31
- all_text_features = hvd.allgather(text_features)
32
- if mlp_loss:
33
- all_audio_features_mlp = hvd.allgather(audio_features_mlp)
34
- all_text_features_mlp = hvd.allgather(text_features_mlp)
35
- else:
36
- with torch.no_grad():
37
- all_audio_features = hvd.allgather(audio_features)
38
- all_text_features = hvd.allgather(text_features)
39
- if mlp_loss:
40
- all_audio_features_mlp = hvd.allgather(audio_features_mlp)
41
- all_text_features_mlp = hvd.allgather(text_features_mlp)
42
- if not local_loss:
43
- # ensure grads for local rank when all_* features don't have a gradient
44
- gathered_audio_features = list(
45
- all_audio_features.chunk(world_size, dim=0)
46
- )
47
- gathered_text_features = list(
48
- all_text_features.chunk(world_size, dim=0)
49
- )
50
- gathered_audio_features[rank] = audio_features
51
- gathered_text_features[rank] = text_features
52
- all_audio_features = torch.cat(gathered_audio_features, dim=0)
53
- all_text_features = torch.cat(gathered_text_features, dim=0)
54
- if mlp_loss:
55
- gathered_audio_features_mlp = list(
56
- all_audio_features_mlp.chunk(world_size, dim=0)
57
- )
58
- gathered_text_features_mlp = list(
59
- all_text_features_mlp.chunk(world_size, dim=0)
60
- )
61
- gathered_audio_features_mlp[rank] = audio_features_mlp
62
- gathered_text_features_mlp[rank] = text_features_mlp
63
- all_audio_features_mlp = torch.cat(
64
- gathered_audio_features_mlp, dim=0
65
- )
66
- all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
67
- else:
68
- # We gather tensors from all gpus
69
- if gather_with_grad:
70
- all_audio_features = torch.cat(
71
- torch.distributed.nn.all_gather(audio_features), dim=0
72
- )
73
- all_text_features = torch.cat(
74
- torch.distributed.nn.all_gather(text_features), dim=0
75
- )
76
- if mlp_loss:
77
- all_audio_features_mlp = torch.cat(
78
- torch.distributed.nn.all_gather(audio_features_mlp), dim=0
79
- )
80
- all_text_features_mlp = torch.cat(
81
- torch.distributed.nn.all_gather(text_features_mlp), dim=0
82
- )
83
- else:
84
- gathered_audio_features = [
85
- torch.zeros_like(audio_features) for _ in range(world_size)
86
- ]
87
- gathered_text_features = [
88
- torch.zeros_like(text_features) for _ in range(world_size)
89
- ]
90
- dist.all_gather(gathered_audio_features, audio_features)
91
- dist.all_gather(gathered_text_features, text_features)
92
- if mlp_loss:
93
- gathered_audio_features_mlp = [
94
- torch.zeros_like(audio_features_mlp) for _ in range(world_size)
95
- ]
96
- gathered_text_features_mlp = [
97
- torch.zeros_like(text_features_mlp) for _ in range(world_size)
98
- ]
99
- dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
100
- dist.all_gather(gathered_text_features_mlp, text_features_mlp)
101
- if not local_loss:
102
- # ensure grads for local rank when all_* features don't have a gradient
103
- gathered_audio_features[rank] = audio_features
104
- gathered_text_features[rank] = text_features
105
- if mlp_loss:
106
- gathered_audio_features_mlp[rank] = audio_features_mlp
107
- gathered_text_features_mlp[rank] = text_features_mlp
108
-
109
- all_audio_features = torch.cat(gathered_audio_features, dim=0)
110
- all_text_features = torch.cat(gathered_text_features, dim=0)
111
- if mlp_loss:
112
- all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
113
- all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
114
- if mlp_loss:
115
- return (
116
- all_audio_features,
117
- all_text_features,
118
- all_audio_features_mlp,
119
- all_text_features_mlp,
120
- )
121
- else:
122
- return all_audio_features, all_text_features
123
-
124
-
125
- class ClipLoss(nn.Module):
126
- def __init__(
127
- self,
128
- local_loss=False,
129
- gather_with_grad=False,
130
- cache_labels=False,
131
- rank=0,
132
- world_size=1,
133
- use_horovod=False,
134
- mlp_loss=False,
135
- weight_loss_kappa=0,
136
- ):
137
- super().__init__()
138
- self.local_loss = local_loss
139
- self.gather_with_grad = gather_with_grad
140
- self.cache_labels = cache_labels
141
- self.rank = rank
142
- self.world_size = world_size
143
- self.use_horovod = use_horovod
144
- self.mlp_loss = mlp_loss
145
- self.weighted_loss = bool(weight_loss_kappa != 0)
146
- self.weight_loss_kappa = weight_loss_kappa
147
- # cache state
148
- self.prev_num_logits = 0
149
- self.labels = {}
150
-
151
- def forward(
152
- self,
153
- audio_features,
154
- text_features,
155
- logit_scale_a,
156
- logit_scale_t=None,
157
- audio_features_mlp=None,
158
- text_features_mlp=None,
159
- ):
160
- device = audio_features.device
161
- if self.mlp_loss:
162
- if self.world_size > 1:
163
- (
164
- all_audio_features,
165
- all_text_features,
166
- all_audio_features_mlp,
167
- all_text_features_mlp,
168
- ) = gather_features(
169
- audio_features=audio_features,
170
- text_features=text_features,
171
- audio_features_mlp=audio_features_mlp,
172
- text_features_mlp=text_features_mlp,
173
- local_loss=self.local_loss,
174
- gather_with_grad=self.gather_with_grad,
175
- rank=self.rank,
176
- world_size=self.world_size,
177
- use_horovod=self.use_horovod,
178
- mlp_loss=self.mlp_loss,
179
- )
180
- if self.local_loss:
181
- a_logits_per_audio = (
182
- logit_scale_a * audio_features @ all_text_features_mlp.T
183
- )
184
- a_logits_per_text = (
185
- logit_scale_a * text_features_mlp @ all_audio_features.T
186
- )
187
- t_logits_per_audio = (
188
- logit_scale_t * audio_features_mlp @ all_text_features.T
189
- )
190
- t_logits_per_text = (
191
- logit_scale_t * text_features @ all_audio_features_mlp.T
192
- )
193
- else:
194
- a_logits_per_audio = (
195
- logit_scale_a * all_audio_features @ all_text_features_mlp.T
196
- )
197
- a_logits_per_text = a_logits_per_audio.T
198
- t_logits_per_audio = (
199
- logit_scale_t * all_audio_features_mlp @ all_text_features.T
200
- )
201
- t_logits_per_text = t_logits_per_audio.T
202
- else:
203
- a_logits_per_audio = (
204
- logit_scale_a * audio_features @ text_features_mlp.T
205
- )
206
- a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
207
- t_logits_per_audio = (
208
- logit_scale_t * audio_features_mlp @ text_features.T
209
- )
210
- t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T
211
-
212
- # calculate ground-truth labels and cache them if enabled
213
- num_logits = a_logits_per_audio.shape[0]
214
- if self.prev_num_logits != num_logits or device not in self.labels:
215
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
216
- if self.world_size > 1 and self.local_loss:
217
- labels = labels + num_logits * self.rank
218
- if self.cache_labels:
219
- self.labels[device] = labels
220
- self.prev_num_logits = num_logits
221
- else:
222
- labels = self.labels[device]
223
-
224
- if not self.weighted_loss:
225
- total_loss = (
226
- F.cross_entropy(a_logits_per_audio, labels)
227
- + F.cross_entropy(a_logits_per_text, labels)
228
- + F.cross_entropy(t_logits_per_audio, labels)
229
- + F.cross_entropy(t_logits_per_text, labels)
230
- ) / 4
231
- else:
232
- audio_weight = (audio_features @ audio_features.T).detach()
233
- audio_weight = (
234
- torch.exp(
235
- torch.sum(audio_weight, axis=1)
236
- / (self.weight_loss_kappa * len(audio_weight))
237
- )
238
- ).detach()
239
- text_weight = (text_features @ text_features.T).detach()
240
- text_weight = (
241
- torch.exp(
242
- torch.sum(text_weight, axis=1)
243
- / (self.weight_loss_kappa * len(text_features))
244
- )
245
- ).detach()
246
- total_loss = (
247
- F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
248
- + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
249
- + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
250
- + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
251
- ) / 4
252
- else:
253
- if self.world_size > 1:
254
- all_audio_features, all_text_features = gather_features(
255
- audio_features=audio_features,
256
- text_features=text_features,
257
- local_loss=self.local_loss,
258
- gather_with_grad=self.gather_with_grad,
259
- rank=self.rank,
260
- world_size=self.world_size,
261
- use_horovod=self.use_horovod,
262
- mlp_loss=self.mlp_loss,
263
- )
264
-
265
- if self.local_loss:
266
- logits_per_audio = (
267
- logit_scale_a * audio_features @ all_text_features.T
268
- )
269
- logits_per_text = (
270
- logit_scale_a * text_features @ all_audio_features.T
271
- )
272
- else:
273
- logits_per_audio = (
274
- logit_scale_a * all_audio_features @ all_text_features.T
275
- )
276
- logits_per_text = logits_per_audio.T
277
- else:
278
- logits_per_audio = logit_scale_a * audio_features @ text_features.T
279
- logits_per_text = logit_scale_a * text_features @ audio_features.T
280
-
281
- # calculate ground-truth labels and cache them if enabled
282
- num_logits = logits_per_audio.shape[0]
283
- if self.prev_num_logits != num_logits or device not in self.labels:
284
- labels = torch.arange(num_logits, device=device, dtype=torch.long)
285
- if self.world_size > 1 and self.local_loss:
286
- labels = labels + num_logits * self.rank
287
- if self.cache_labels:
288
- self.labels[device] = labels
289
- self.prev_num_logits = num_logits
290
- else:
291
- labels = self.labels[device]
292
- if not self.weighted_loss:
293
- total_loss = (
294
- F.cross_entropy(logits_per_audio, labels)
295
- + F.cross_entropy(logits_per_text, labels)
296
- ) / 2
297
- else:
298
- audio_weight = (all_audio_features @ all_audio_features.T).detach()
299
- audio_weight = (
300
- torch.exp(
301
- torch.sum(audio_weight, axis=1)
302
- / (self.weight_loss_kappa * len(all_audio_features))
303
- )
304
- ).detach()
305
- text_weight = (all_text_features @ all_text_features.T).detach()
306
- text_weight = (
307
- torch.exp(
308
- torch.sum(text_weight, axis=1)
309
- / (self.weight_loss_kappa * len(all_text_features))
310
- )
311
- ).detach()
312
- total_loss = (
313
- F.cross_entropy(logits_per_audio, labels, weight=text_weight)
314
- + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
315
- ) / 2
316
- return total_loss
317
-
318
-
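A minimal sketch of the core of ClipLoss.forward above for the single-process case (world_size=1, no mlp_loss, no weighting): symmetric cross-entropy over the scaled audio-text similarity matrix, with the diagonal as ground truth. Batch size, feature dimension and logit scale are illustrative.

# Sketch: the single-process, unweighted core of ClipLoss.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D = 8, 512
audio = F.normalize(torch.randn(B, D), dim=-1)
text = F.normalize(torch.randn(B, D), dim=-1)
logit_scale_a = torch.tensor(100.0)

logits_per_audio = logit_scale_a * audio @ text.T
logits_per_text = logit_scale_a * text @ audio.T
labels = torch.arange(B)                 # i-th audio matches i-th caption
loss = (F.cross_entropy(logits_per_audio, labels)
        + F.cross_entropy(logits_per_text, labels)) / 2
print(float(loss))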
319
- def lp_gather_features(pred, target, world_size=1, use_horovod=False):
320
- if use_horovod:
321
- assert hvd is not None, "Please install horovod"
322
- with torch.no_grad():
323
- all_preds = hvd.allgather(pred)
324
- all_targets = hvd.allgather(target)
325
- else:
326
- gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
327
- gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]
328
-
329
- dist.all_gather(gathered_preds, pred)
330
- dist.all_gather(gathered_targets, target)
331
- all_preds = torch.cat(gathered_preds, dim=0)
332
- all_targets = torch.cat(gathered_targets, dim=0)
333
-
334
- return all_preds, all_targets
335
-
336
-
337
- def get_map(pred, target):
338
- pred = torch.sigmoid(pred).numpy()
339
- target = target.numpy()
340
- return np.mean(average_precision_score(target, pred, average=None))
341
-
342
-
343
- def get_acc(pred, target):
344
- pred = torch.argmax(pred, 1).numpy()
345
- target = torch.argmax(target, 1).numpy()
346
- return accuracy_score(target, pred)
347
-
348
-
349
- def get_mauc(pred, target):
350
- pred = torch.sigmoid(pred).numpy()
351
- target = target.numpy()
352
- return np.mean(roc_auc_score(target, pred, average=None))
353
-
354
-
355
- class LPMetrics(object):
356
- def __init__(self, metric_names=["map", "acc", "mauc"]):
357
- self.metrics = []
358
- for name in metric_names:
359
- self.metrics.append(self.get_metric(name))
360
- self.metric_names = metric_names
361
-
362
- def get_metric(self, name):
363
- if name == "map":
364
- return get_map
365
- elif name == "acc":
366
- return get_acc
367
- elif name == "mauc":
368
- return get_mauc
369
- else:
370
- raise ValueError(f"the metric should be at least one of [map, acc, mauc]")
371
-
372
- def evaluate_mertics(self, pred, target):
373
- metric_dict = {}
374
- for i in range(len(self.metric_names)):
375
- metric_dict[self.metric_names[i]] = self.metrics[i](pred, target)
376
- return metric_dict
377
-
378
-
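A small usage sketch of LPMetrics on random multi-label predictions, assuming the module above is importable; the first two target rows are forced to all-ones and all-zeros only so that every class has both a positive and a negative example (otherwise roc_auc_score raises), and the data is otherwise illustrative.

# Sketch: evaluating random multi-label predictions with the metrics above.
import torch

torch.manual_seed(0)
pred = torch.randn(32, 10)                       # raw logits
target = (torch.rand(32, 10) > 0.7).float()
target[0] = 1.0                                   # at least one positive per class
target[1] = 0.0                                   # at least one negative per class
metrics = LPMetrics(metric_names=["map", "mauc"])
print(metrics.evaluate_mertics(pred, target))     # e.g. {'map': ..., 'mauc': ...}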
379
- def calc_celoss(pred, target):
380
- target = torch.argmax(target, 1).long()
381
- return nn.CrossEntropyLoss()(pred, target)
382
-
383
-
384
- class LPLoss(nn.Module):
385
- def __init__(self, loss_name):
386
- super().__init__()
387
- if loss_name == "bce":
388
- self.loss_func = nn.BCEWithLogitsLoss()
389
- elif loss_name == "ce":
390
- self.loss_func = calc_celoss
391
- elif loss_name == "mse":
392
- self.loss_func = nn.MSELoss()
393
- else:
394
- raise ValueError(f"the loss func should be at least one of [bce, ce, mse]")
395
-
396
- def forward(self, pred, target):
397
- loss = self.loss_func(pred, target)
398
- return loss
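A brief sketch of the three LPLoss modes on dummy predictions, assuming the module above is importable; shapes and targets are illustrative.

# Sketch: the three LPLoss modes on dummy predictions.
import torch

pred = torch.randn(4, 10)
multi_hot = (torch.rand(4, 10) > 0.5).float()
one_hot = torch.eye(10)[torch.randint(0, 10, (4,))]

print(float(LPLoss("bce")(pred, multi_hot)))   # multi-label classification
print(float(LPLoss("ce")(pred, one_hot)))      # single-label (argmax of one-hot)
print(float(LPLoss("mse")(pred, multi_hot)))   # regression-style target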
audioldm/clap/open_clip/model.py DELETED
@@ -1,936 +0,0 @@
1
- """ CLAP Model
2
-
3
- Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- Adapted to the Audio Task.
5
- """
6
-
7
- from collections import OrderedDict
8
- from dataclasses import dataclass
9
- from email.mime import audio
10
- from typing import Tuple, Union, Callable, Optional
11
-
12
- import numpy as np
13
- import torch
14
- import torch.nn.functional as F
15
- from torch import nn
16
-
17
- from .timm_model import TimmModel
18
- import logging
19
- from .utils import freeze_batch_norm_2d
20
-
21
- from .pann_model import create_pann_model
22
- from .htsat import create_htsat_model
23
- from transformers import BertModel, RobertaModel, BartModel
24
- from transformers.tokenization_utils_base import BatchEncoding
25
-
26
-
27
- class MLPLayers(nn.Module):
28
- def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
29
- super(MLPLayers, self).__init__()
30
- self.nonlin = nonlin
31
- self.dropout = dropout
32
-
33
- sequence = []
34
- for u0, u1 in zip(units[:-1], units[1:]):
35
- sequence.append(nn.Linear(u0, u1))
36
- sequence.append(self.nonlin)
37
- sequence.append(nn.Dropout(self.dropout))
38
- sequence = sequence[:-2]
39
-
40
- self.sequential = nn.Sequential(*sequence)
41
-
42
- def forward(self, X):
43
- X = self.sequential(X)
44
- return X
45
-
46
-
47
- class Bottleneck(nn.Module):
48
- expansion = 4
49
-
50
- def __init__(self, inplanes, planes, stride=1):
51
- super().__init__()
52
-
53
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
54
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
55
- self.bn1 = nn.BatchNorm2d(planes)
56
-
57
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
58
- self.bn2 = nn.BatchNorm2d(planes)
59
-
60
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
61
-
62
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
63
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
64
-
65
- self.relu = nn.ReLU(inplace=True)
66
- self.downsample = None
67
- self.stride = stride
68
-
69
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
70
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
71
- self.downsample = nn.Sequential(
72
- OrderedDict(
73
- [
74
- ("-1", nn.AvgPool2d(stride)),
75
- (
76
- "0",
77
- nn.Conv2d(
78
- inplanes,
79
- planes * self.expansion,
80
- 1,
81
- stride=1,
82
- bias=False,
83
- ),
84
- ),
85
- ("1", nn.BatchNorm2d(planes * self.expansion)),
86
- ]
87
- )
88
- )
89
-
90
- def forward(self, x: torch.Tensor):
91
- identity = x
92
-
93
- out = self.relu(self.bn1(self.conv1(x)))
94
- out = self.relu(self.bn2(self.conv2(out)))
95
- out = self.avgpool(out)
96
- out = self.bn3(self.conv3(out))
97
-
98
- if self.downsample is not None:
99
- identity = self.downsample(x)
100
-
101
- out += identity
102
- out = self.relu(out)
103
- return out
104
-
105
-
106
- class AttentionPool2d(nn.Module):
107
- def __init__(
108
- self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
109
- ):
110
- super().__init__()
111
- self.positional_embedding = nn.Parameter(
112
- torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
113
- )
114
- self.k_proj = nn.Linear(embed_dim, embed_dim)
115
- self.q_proj = nn.Linear(embed_dim, embed_dim)
116
- self.v_proj = nn.Linear(embed_dim, embed_dim)
117
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
118
- self.num_heads = num_heads
119
-
120
- def forward(self, x):
121
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
122
- 2, 0, 1
123
- ) # NCHW -> (HW)NC
124
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
125
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
126
- x, _ = F.multi_head_attention_forward(
127
- query=x,
128
- key=x,
129
- value=x,
130
- embed_dim_to_check=x.shape[-1],
131
- num_heads=self.num_heads,
132
- q_proj_weight=self.q_proj.weight,
133
- k_proj_weight=self.k_proj.weight,
134
- v_proj_weight=self.v_proj.weight,
135
- in_proj_weight=None,
136
- in_proj_bias=torch.cat(
137
- [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
138
- ),
139
- bias_k=None,
140
- bias_v=None,
141
- add_zero_attn=False,
142
- dropout_p=0,
143
- out_proj_weight=self.c_proj.weight,
144
- out_proj_bias=self.c_proj.bias,
145
- use_separate_proj_weight=True,
146
- training=self.training,
147
- need_weights=False,
148
- )
149
-
150
- return x[0]
151
-
152
-
153
- class ModifiedResNet(nn.Module):
154
- """
155
- A ResNet class that is similar to torchvision's but contains the following changes:
156
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
157
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
158
- - The final pooling layer is a QKV attention instead of an average pool
159
- """
160
-
161
- def __init__(self, layers, output_dim, heads, image_size=224, width=64):
162
- super().__init__()
163
- self.output_dim = output_dim
164
- self.image_size = image_size
165
-
166
- # the 3-layer stem
167
- self.conv1 = nn.Conv2d(
168
- 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
169
- )
170
- self.bn1 = nn.BatchNorm2d(width // 2)
171
- self.conv2 = nn.Conv2d(
172
- width // 2, width // 2, kernel_size=3, padding=1, bias=False
173
- )
174
- self.bn2 = nn.BatchNorm2d(width // 2)
175
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
176
- self.bn3 = nn.BatchNorm2d(width)
177
- self.avgpool = nn.AvgPool2d(2)
178
- self.relu = nn.ReLU(inplace=True)
179
-
180
- # residual layers
181
- self._inplanes = width # this is a *mutable* variable used during construction
182
- self.layer1 = self._make_layer(width, layers[0])
183
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
184
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
185
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
186
-
187
- embed_dim = width * 32 # the ResNet feature dimension
188
- self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
189
-
190
- self.init_parameters()
191
-
192
- def _make_layer(self, planes, blocks, stride=1):
193
- layers = [Bottleneck(self._inplanes, planes, stride)]
194
-
195
- self._inplanes = planes * Bottleneck.expansion
196
- for _ in range(1, blocks):
197
- layers.append(Bottleneck(self._inplanes, planes))
198
-
199
- return nn.Sequential(*layers)
200
-
201
- def init_parameters(self):
202
- if self.attnpool is not None:
203
- std = self.attnpool.c_proj.in_features**-0.5
204
- nn.init.normal_(self.attnpool.q_proj.weight, std=std)
205
- nn.init.normal_(self.attnpool.k_proj.weight, std=std)
206
- nn.init.normal_(self.attnpool.v_proj.weight, std=std)
207
- nn.init.normal_(self.attnpool.c_proj.weight, std=std)
208
-
209
- for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
210
- for name, param in resnet_block.named_parameters():
211
- if name.endswith("bn3.weight"):
212
- nn.init.zeros_(param)
213
-
214
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
215
- assert (
216
- unlocked_groups == 0
217
- ), "partial locking not currently supported for this model"
218
- for param in self.parameters():
219
- param.requires_grad = False
220
- if freeze_bn_stats:
221
- freeze_batch_norm_2d(self)
222
-
223
- def stem(self, x):
224
- for conv, bn in [
225
- (self.conv1, self.bn1),
226
- (self.conv2, self.bn2),
227
- (self.conv3, self.bn3),
228
- ]:
229
- x = self.relu(bn(conv(x)))
230
- x = self.avgpool(x)
231
- return x
232
-
233
- def forward(self, x):
234
- x = self.stem(x)
235
- x = self.layer1(x)
236
- x = self.layer2(x)
237
- x = self.layer3(x)
238
- x = self.layer4(x)
239
- x = self.attnpool(x)
240
-
241
- return x
242
-
243
-
244
- class LayerNorm(nn.LayerNorm):
245
- """Subclass torch's LayerNorm to handle fp16."""
246
-
247
- def forward(self, x: torch.Tensor):
248
- orig_type = x.dtype
249
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
250
- return x.to(orig_type)
251
-
252
-
253
- class QuickGELU(nn.Module):
254
- # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
255
- def forward(self, x: torch.Tensor):
256
- return x * torch.sigmoid(1.702 * x)
257
-
258
-
259
- class ResidualAttentionBlock(nn.Module):
260
- def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
261
- super().__init__()
262
-
263
- self.attn = nn.MultiheadAttention(d_model, n_head)
264
- self.ln_1 = LayerNorm(d_model)
265
- self.mlp = nn.Sequential(
266
- OrderedDict(
267
- [
268
- ("c_fc", nn.Linear(d_model, d_model * 4)),
269
- ("gelu", act_layer()),
270
- ("c_proj", nn.Linear(d_model * 4, d_model)),
271
- ]
272
- )
273
- )
274
- self.ln_2 = LayerNorm(d_model)
275
-
276
- def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
277
- return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
278
-
279
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
280
- x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
281
- x = x + self.mlp(self.ln_2(x))
282
- return x
283
-
284
-
285
- class Transformer(nn.Module):
286
- def __init__(
287
- self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
288
- ):
289
- super().__init__()
290
- self.width = width
291
- self.layers = layers
292
- self.resblocks = nn.ModuleList(
293
- [
294
- ResidualAttentionBlock(width, heads, act_layer=act_layer)
295
- for _ in range(layers)
296
- ]
297
- )
298
-
299
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
300
- for r in self.resblocks:
301
- x = r(x, attn_mask=attn_mask)
302
- return x
303
-
304
-
305
- class VisualTransformer(nn.Module):
306
- def __init__(
307
- self,
308
- image_size: int,
309
- patch_size: int,
310
- width: int,
311
- layers: int,
312
- heads: int,
313
- output_dim: int,
314
- act_layer: Callable = nn.GELU,
315
- ):
316
- super().__init__()
317
- self.image_size = image_size
318
- self.output_dim = output_dim
319
- self.conv1 = nn.Conv2d(
320
- in_channels=3,
321
- out_channels=width,
322
- kernel_size=patch_size,
323
- stride=patch_size,
324
- bias=False,
325
- )
326
-
327
- scale = width**-0.5
328
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
329
- self.positional_embedding = nn.Parameter(
330
- scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
331
- )
332
- self.ln_pre = LayerNorm(width)
333
-
334
- self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
335
-
336
- self.ln_post = LayerNorm(width)
337
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
338
-
339
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
340
- assert (
341
- unlocked_groups == 0
342
- ), "partial locking not currently supported for this model"
343
- for param in self.parameters():
344
- param.requires_grad = False
345
-
346
- def forward(self, x: torch.Tensor):
347
- x = self.conv1(x) # shape = [*, width, grid, grid]
348
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
349
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
350
- x = torch.cat(
351
- [
352
- self.class_embedding.to(x.dtype)
353
- + torch.zeros(
354
- x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
355
- ),
356
- x,
357
- ],
358
- dim=1,
359
- ) # shape = [*, grid ** 2 + 1, width]
360
- x = x + self.positional_embedding.to(x.dtype)
361
- x = self.ln_pre(x)
362
-
363
- x = x.permute(1, 0, 2) # NLD -> LND
364
- x = self.text_branch(x)
365
- x = x.permute(1, 0, 2) # LND -> NLD
366
-
367
- x = self.ln_post(x[:, 0, :])
368
-
369
- if self.proj is not None:
370
- x = x @ self.proj
371
-
372
- return x
373
-
374
-
375
- @dataclass
376
- class CLAPVisionCfg:
377
- layers: Union[Tuple[int, int, int, int], int] = 12
378
- width: int = 768
379
- patch_size: int = 16
380
- image_size: Union[Tuple[int, int], int] = 224
381
- timm_model_name: str = (
382
- None # a valid model name overrides layers, width, patch_size
383
- )
384
- timm_model_pretrained: bool = (
385
- False # use (imagenet) pretrained weights for named model
386
- )
387
- timm_pool: str = (
388
- "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
389
- )
390
- timm_proj: str = (
391
- "linear" # linear projection for timm model output ('linear', 'mlp', '')
392
- )
393
-
394
-
395
- # Audio Config Class
396
- @dataclass
397
- class CLAPAudioCfp:
398
- model_type: str = "PANN"
399
- model_name: str = "Cnn14"
400
- sample_rate: int = 48000
401
- # Param
402
- audio_length: int = 1024
403
- window_size: int = 1024
404
- hop_size: int = 1024
405
- fmin: int = 50
406
- fmax: int = 14000
407
- class_num: int = 527
408
- mel_bins: int = 64
409
- clip_samples: int = 480000
410
-
411
-
412
- @dataclass
413
- class CLAPTextCfg:
414
- context_length: int
415
- vocab_size: int
416
- width: int
417
- heads: int
418
- layers: int
419
- model_type: str
420
-
421
-
422
- class CLAP(nn.Module):
423
- def __init__(
424
- self,
425
- embed_dim: int,
426
- audio_cfg: CLAPAudioCfp,
427
- text_cfg: CLAPTextCfg,
428
- quick_gelu: bool = False,
429
- enable_fusion: bool = False,
430
- fusion_type: str = "None",
431
- joint_embed_shape: int = 512,
432
- mlp_act: str = "relu",
433
- ):
434
- super().__init__()
435
- if isinstance(audio_cfg, dict):
436
- audio_cfg = CLAPAudioCfp(**audio_cfg)
437
- if isinstance(text_cfg, dict):
438
- text_cfg = CLAPTextCfg(**text_cfg)
439
-
440
- self.audio_cfg = audio_cfg
441
- self.text_cfg = text_cfg
442
- self.enable_fusion = enable_fusion
443
- self.fusion_type = fusion_type
444
- self.joint_embed_shape = joint_embed_shape
445
- self.mlp_act = mlp_act
446
-
447
- self.context_length = text_cfg.context_length
448
-
449
- # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
450
- # memory efficient in recent PyTorch releases (>= 1.10).
451
- # NOTE: timm models always use native GELU regardless of quick_gelu flag.
452
- act_layer = QuickGELU if quick_gelu else nn.GELU
453
-
454
- if mlp_act == "relu":
455
- mlp_act_layer = nn.ReLU()
456
- elif mlp_act == "gelu":
457
- mlp_act_layer = nn.GELU()
458
- else:
459
- raise NotImplementedError
460
-
461
- # audio branch
462
- # audio branch parameters
463
- if audio_cfg.model_type == "PANN":
464
- self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
465
- elif audio_cfg.model_type == "HTSAT":
466
- self.audio_branch = create_htsat_model(
467
- audio_cfg, enable_fusion, fusion_type
468
- )
469
- else:
470
- logging.error(f"Model config for {audio_cfg.model_type} not found")
471
- raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
472
-
473
- # text branch
474
- # text branch parameters
475
- if text_cfg.model_type == "transformer":
476
- self.text_branch = Transformer(
477
- width=text_cfg.width,
478
- layers=text_cfg.layers,
479
- heads=text_cfg.heads,
480
- act_layer=act_layer,
481
- )
482
- self.vocab_size = text_cfg.vocab_size
483
- self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
484
- self.positional_embedding = nn.Parameter(
485
- torch.empty(self.context_length, text_cfg.width)
486
- )
487
- self.ln_final = LayerNorm(text_cfg.width)
488
- self.text_transform = MLPLayers(
489
- units=[
490
- self.joint_embed_shape,
491
- self.joint_embed_shape,
492
- self.joint_embed_shape,
493
- ],
494
- dropout=0.1,
495
- )
496
- self.text_projection = nn.Sequential(
497
- nn.Linear(text_cfg.width, self.joint_embed_shape),
498
- mlp_act_layer,
499
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
500
- )
501
- elif text_cfg.model_type == "bert":
502
- self.text_branch = BertModel.from_pretrained("bert-base-uncased")
503
- self.text_transform = MLPLayers(
504
- units=[
505
- self.joint_embed_shape,
506
- self.joint_embed_shape,
507
- self.joint_embed_shape,
508
- ],
509
- dropout=0.1,
510
- )
511
- self.text_projection = nn.Sequential(
512
- nn.Linear(768, self.joint_embed_shape),
513
- mlp_act_layer,
514
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
515
- )
516
- elif text_cfg.model_type == "roberta":
517
- self.text_branch = RobertaModel.from_pretrained("roberta-base")
518
- self.text_transform = MLPLayers(
519
- units=[
520
- self.joint_embed_shape,
521
- self.joint_embed_shape,
522
- self.joint_embed_shape,
523
- ],
524
- dropout=0.1,
525
- )
526
- self.text_projection = nn.Sequential(
527
- nn.Linear(768, self.joint_embed_shape),
528
- mlp_act_layer,
529
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
530
- )
531
- elif text_cfg.model_type == "bart":
532
- self.text_branch = BartModel.from_pretrained("facebook/bart-base")
533
- self.text_transform = MLPLayers(
534
- units=[
535
- self.joint_embed_shape,
536
- self.joint_embed_shape,
537
- self.joint_embed_shape,
538
- ],
539
- dropout=0.1,
540
- )
541
- self.text_projection = nn.Sequential(
542
- nn.Linear(768, self.joint_embed_shape),
543
- mlp_act_layer,
544
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
545
- )
546
- else:
547
- logging.error(f"Model config for {text_cfg.model_type} not found")
548
- raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
549
- self.text_branch_type = text_cfg.model_type
550
- # text branch parameters
551
-
552
- # audio branch parameters
553
- self.audio_transform = MLPLayers(
554
- units=[
555
- self.joint_embed_shape,
556
- self.joint_embed_shape,
557
- self.joint_embed_shape,
558
- ],
559
- dropout=0.1,
560
- )
561
-
562
- # below here is text branch parameters
563
-
564
- # ============================================================================================================
565
- self.audio_projection = nn.Sequential(
566
- nn.Linear(embed_dim, self.joint_embed_shape),
567
- mlp_act_layer,
568
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
569
- )
570
-
571
- self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
572
- self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
573
- self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
574
-
575
- self.init_text_branch_parameters()
576
-
577
- def init_text_branch_parameters(self):
578
- if self.text_branch_type == "transformer":
579
- nn.init.normal_(self.token_embedding.weight, std=0.02)
580
- nn.init.normal_(self.positional_embedding, std=0.01)
581
- proj_std = (self.text_branch.width**-0.5) * (
582
- (2 * self.text_branch.layers) ** -0.5
583
- )
584
- attn_std = self.text_branch.width**-0.5
585
- fc_std = (2 * self.text_branch.width) ** -0.5
586
- for block in self.text_branch.resblocks:
587
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
588
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
589
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
590
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
591
- if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
592
- width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
593
- elif self.text_branch_type == "bart":
594
- width = self.text_branch.shared.weight.shape[-1]
595
- else:
596
- width = self.text_branch.width
597
- nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
598
- nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
599
-
600
- # deprecated
601
- # if hasattr(self.visual, 'init_parameters'):
602
- # self.visual.init_parameters()
603
-
604
- # if self.text_projection is not None:
605
- # nn.init.normal_(self.text_projection, std=width**-0.5)
606
-
607
- def build_attention_mask(self):
608
- # lazily create causal attention mask, with full attention between the vision tokens
609
- # pytorch uses additive attention mask; fill with -inf
610
- mask = torch.empty(self.context_length, self.context_length)
611
- mask.fill_(float("-inf"))
612
- mask.triu_(1) # zero out the lower diagonal
613
- return mask
614
-
615
- def encode_audio(self, audio, device):
616
- return self.audio_branch(
617
- audio, mixup_lambda=None, device=device
618
- ) # mix lambda needs to add
619
-
620
- # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
621
- # tmp = {}
622
- # for k in x[0].keys():
623
- # tmp[k] = []
624
- # for i in range(len(x)):
625
- # tmp[k].append(x[i][k][:77])
626
- # for k in x[0].keys():
627
- # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
628
- # return tmp
629
-
630
- def encode_text(self, text, device):
631
- if self.text_branch_type == "transformer":
632
- text = text.to(device=device, non_blocking=True)
633
- x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
634
-
635
- x = x + self.positional_embedding
636
- x = x.permute(1, 0, 2) # NLD -> LND
637
- x = self.text_branch(x, attn_mask=self.attn_mask)
638
- x = x.permute(1, 0, 2) # LND -> NLD
639
- x = self.ln_final(x)
640
-
641
- # x.shape = [batch_size, n_ctx, transformer.width]
642
- # take features from the eot embedding (eot_token is the highest number in each sequence)
643
- x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
644
- elif self.text_branch_type == "bert":
645
- # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
646
- # text = BatchEncoding(text)
647
- x = self.text_branch(
648
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
649
- attention_mask=text["attention_mask"].to(
650
- device=device, non_blocking=True
651
- ),
652
- token_type_ids=text["token_type_ids"].to(
653
- device=device, non_blocking=True
654
- ),
655
- )["pooler_output"]
656
- x = self.text_projection(x)
657
- elif self.text_branch_type == "roberta":
658
- x = self.text_branch(
659
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
660
- attention_mask=text["attention_mask"].to(
661
- device=device, non_blocking=True
662
- ),
663
- )["pooler_output"]
664
- x = self.text_projection(x)
665
- elif self.text_branch_type == "bart":
666
- x = torch.mean(
667
- self.text_branch(
668
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
669
- attention_mask=text["attention_mask"].to(
670
- device=device, non_blocking=True
671
- ),
672
- )["encoder_last_hidden_state"],
673
- axis=1,
674
- )
675
- x = self.text_projection(x)
676
- else:
677
- logging.error(f"Model type {self.text_branch_type} not found")
678
- raise RuntimeError(f"Model type {self.text_branch_type} not found.")
679
- return x
680
-
681
- def forward(self, audio, text, device=None):
682
- """Forward audio and text into the CLAP
683
-
684
- Parameters
685
- ----------
686
- audio: torch.Tensor (batch_size, audio_length)
687
- the time-domain audio input / the batch of mel_spec and longer list.
688
- text: torch.Tensor () // need to add
689
- the text token input
690
- """
691
- if device is None:
692
- if audio is not None:
693
- device = audio.device
694
- elif text is not None:
695
- device = text.device
696
- if audio is None and text is None:
697
- # a hack to get the logit scale
698
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
699
- elif audio is None:
700
- return self.encode_text(text, device=device)
701
- elif text is None:
702
- return self.audio_projection(
703
- self.encode_audio(audio, device=device)["embedding"]
704
- )
705
- audio_features = self.audio_projection(
706
- self.encode_audio(audio, device=device)["embedding"]
707
- )
708
- audio_features = F.normalize(audio_features, dim=-1)
709
-
710
- text_features = self.encode_text(text, device=device)
711
- # print("text_features", text_features)
712
- # print("text_features.shape", text_features.shape)
713
- # print("text_features.type", type(text_features))
714
- text_features = F.normalize(text_features, dim=-1)
715
-
716
- audio_features_mlp = self.audio_transform(audio_features)
717
- text_features_mlp = self.text_transform(text_features)
718
- # Four outputs: audio features (basic & MLP), text features (basic & MLP)
719
- return (
720
- audio_features,
721
- text_features,
722
- audio_features_mlp,
723
- text_features_mlp,
724
- self.logit_scale_a.exp(),
725
- self.logit_scale_t.exp(),
726
- )
727
-
728
- def get_logit_scale(self):
729
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
730
-
731
- def get_text_embedding(self, data):
732
- """Get the text embedding from the model
733
-
734
- Parameters
735
- ----------
736
- data: torch.Tensor
737
- a tensor of text embedding
738
-
739
- Returns
740
- ----------
741
- text_embed: torch.Tensor
742
- a tensor of text_embeds (N, D)
743
-
744
- """
745
- device = next(self.parameters()).device
746
- for k in data:
747
- data[k] = data[k].to(device)
748
- if(len(data[k].size()) < 2):
749
- data[k] = data[k].unsqueeze(0)
750
- text_embeds = self.encode_text(data, device=device)
751
- text_embeds = F.normalize(text_embeds, dim=-1)
752
-
753
- return text_embeds
754
-
755
- def get_audio_embedding(self, data):
756
- """Get the audio embedding from the model
757
-
758
- Parameters
759
- ----------
760
- data: a list of dict
761
- the audio input dict list from 'get_audio_feature' method
762
-
763
- Returns
764
- ----------
765
- audio_embed: torch.Tensor
766
- a tensor of audio_embeds (N, D)
767
-
768
- """
769
- device = next(self.parameters()).device
770
- input_dict = {}
771
- keys = data[0].keys()
772
- for k in keys:
773
- input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(
774
- device
775
- )
776
-
777
- audio_embeds = self.audio_projection(
778
- self.encode_audio(input_dict, device=device)["embedding"]
779
- )
780
- audio_embeds = F.normalize(audio_embeds, dim=-1)
781
-
782
- return audio_embeds
783
-
784
- def audio_infer(self, audio, hopsize=None, device=None):
785
- """Forward one audio and produce the audio embedding
786
-
787
- Parameters
788
- ----------
789
- audio: (audio_length)
790
- the time-domain audio input, notice that it must be only one input
791
- hopsize: int
792
- the overlap hopsize as the sliding window
793
-
794
- Returns
795
- ----------
796
- output_dict: {
797
- key: [n, (embedding_shape)] if "HTS-AT"
798
- or
799
- key: [(embedding_shape)] if "PANN"
800
- }
801
- the list of key values of the audio branch
802
-
803
- """
804
-
805
- assert not self.training, "the inference mode must be run at eval stage"
806
- output_dict = {}
807
- # PANN
808
- if self.audio_cfg.model_type == "PANN":
809
- audio_input = audio.unsqueeze(dim=0)
810
- output_dict[key] = self.encode_audio(audio_input, device=device)[
811
- key
812
- ].squeeze(dim=0)
813
- elif self.audio_cfg.model_type == "HTSAT":
814
- # repeat
815
- audio_len = len(audio)
816
- k = self.audio_cfg.clip_samples // audio_len
817
- if k > 1:
818
- audio = audio.repeat(k)
819
- audio_len = len(audio)
820
-
821
- if hopsize is None:
822
- hopsize = min(hopsize, audio_len)
823
-
824
- if audio_len > self.audio_cfg.clip_samples:
825
- audio_input = [
826
- audio[pos : pos + self.audio_cfg.clip_samples].clone()
827
- for pos in range(
828
- 0, audio_len - self.audio_cfg.clip_samples, hopsize
829
- )
830
- ]
831
- audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
832
- audio_input = torch.stack(audio_input)
833
- output_dict[key] = self.encode_audio(audio_input, device=device)[key]
834
- else:
835
- audio_input = audio.unsqueeze(dim=0)
836
- output_dict[key] = self.encode_audio(audio_input, device=device)[
837
- key
838
- ].squeeze(dim=0)
839
-
840
- return output_dict
841
-
842
-
843
- def convert_weights_to_fp16(model: nn.Module):
844
- """Convert applicable model parameters to fp16"""
845
-
846
- def _convert_weights_to_fp16(l):
847
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
848
- l.weight.data = l.weight.data.half()
849
- if l.bias is not None:
850
- l.bias.data = l.bias.data.half()
851
-
852
- if isinstance(l, nn.MultiheadAttention):
853
- for attr in [
854
- *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
855
- "in_proj_bias",
856
- "bias_k",
857
- "bias_v",
858
- ]:
859
- tensor = getattr(l, attr)
860
- if tensor is not None:
861
- tensor.data = tensor.data.half()
862
-
863
- for name in ["text_projection", "proj"]:
864
- if hasattr(l, name):
865
- attr = getattr(l, name)
866
- if attr is not None:
867
- attr.data = attr.data.half()
868
-
869
- model.apply(_convert_weights_to_fp16)
870
-
871
-
872
- # Ignore the state dict of the vision part
873
- def build_model_from_openai_state_dict(
874
- state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
875
- ):
876
-
877
- embed_dim = model_cfg["embed_dim"]
878
- audio_cfg = model_cfg["audio_cfg"]
879
- text_cfg = model_cfg["text_cfg"]
880
- context_length = state_dict["positional_embedding"].shape[0]
881
- vocab_size = state_dict["token_embedding.weight"].shape[0]
882
- transformer_width = state_dict["ln_final.weight"].shape[0]
883
- transformer_heads = transformer_width // 64
884
- transformer_layers = len(
885
- set(
886
- k.split(".")[2]
887
- for k in state_dict
888
- if k.startswith(f"transformer.resblocks")
889
- )
890
- )
891
-
892
- audio_cfg = CLAPAudioCfp(**audio_cfg)
893
- text_cfg = CLAPTextCfg(**text_cfg)
894
-
895
- model = CLAP(
896
- embed_dim,
897
- audio_cfg=audio_cfg,
898
- text_cfg=text_cfg,
899
- quick_gelu=True, # OpenAI models were trained with QuickGELU
900
- enable_fusion=enable_fusion,
901
- fusion_type=fusion_type,
902
- )
903
- state_dict["logit_scale_a"] = state_dict["logit_scale"]
904
- state_dict["logit_scale_t"] = state_dict["logit_scale"]
905
- pop_keys = list(state_dict.keys())[::]
906
- # pop the visual branch saved weights
907
- for key in pop_keys:
908
- if key.startswith("visual."):
909
- state_dict.pop(key, None)
910
-
911
- for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
912
- state_dict.pop(key, None)
913
-
914
- # not use fp16
915
- # convert_weights_to_fp16(model)
916
- model.load_state_dict(state_dict, strict=False)
917
- return model.eval()
918
-
919
-
920
- def trace_model(model, batch_size=256, device=torch.device("cpu")):
921
- model.eval()
922
- audio_length = model.audio_cfg.audio_length
923
- example_audio = torch.ones((batch_size, audio_length), device=device)
924
- example_text = torch.zeros(
925
- (batch_size, model.context_length), dtype=torch.int, device=device
926
- )
927
- model = torch.jit.trace_module(
928
- model,
929
- inputs=dict(
930
- forward=(example_audio, example_text),
931
- encode_text=(example_text,),
932
- encode_image=(example_audio,),
933
- ),
934
- )
935
- model.audio_cfg.audio_length = audio_length # Question: what does this do?
936
- return model
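Note on the file removed above: `CLAP.forward()` returns normalized audio and text features (plus their MLP-transformed variants and two logit scales). The short Python sketch below uses random stand-in tensors instead of a real checkpoint to show how such a six-tuple is typically turned into contrastive logits; the symmetric cross-entropy at the end follows the standard CLIP/CLAP recipe and is included for illustration only, not taken from this repository.

import torch
import torch.nn.functional as F

batch, dim = 4, 512                      # 512 matches the default joint_embed_shape above
audio_features = F.normalize(torch.randn(batch, dim), dim=-1)  # stand-in for CLAP audio output
text_features = F.normalize(torch.randn(batch, dim), dim=-1)   # stand-in for CLAP text output
logit_scale_a = torch.tensor(1.0 / 0.07)  # value of self.logit_scale_a.exp() at initialization

# Audio-to-text similarity matrix; the diagonal holds the matched pairs.
logits_per_audio = logit_scale_a * audio_features @ text_features.t()
logits_per_text = logits_per_audio.t()

# Symmetric InfoNCE-style objective over the matched pairs (illustrative only).
labels = torch.arange(batch)
loss = (F.cross_entropy(logits_per_audio, labels)
        + F.cross_entropy(logits_per_text, labels)) / 2
print(loss.item())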
audioldm/clap/open_clip/model_configs/HTSAT-base.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "base"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
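The JSON removed above is one of several model_configs entries deleted in this commit. A minimal sketch, assuming these files are consumed roughly as follows (the repo's factory.py, not shown in this section, does the actual loading): "embed_dim" becomes the first argument of CLAP(...), while "audio_cfg" and "text_cfg" are passed as dicts and converted to CLAPAudioCfp / CLAPTextCfg inside CLAP.__init__ above. The path below is simply the file named in this diff.

import json
from pathlib import Path

cfg_path = Path("audioldm/clap/open_clip/model_configs/HTSAT-base.json")
cfg = json.loads(cfg_path.read_text())

print(cfg["embed_dim"])                   # 1024
print(cfg["audio_cfg"]["model_type"])     # "HTSAT"
print(cfg["text_cfg"]["context_length"])  # 77

# Hypothetical construction (requires the CLAP class from model.py above):
# model = CLAP(cfg["embed_dim"], audio_cfg=cfg["audio_cfg"], text_cfg=cfg["text_cfg"])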
audioldm/clap/open_clip/model_configs/HTSAT-large.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "large"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1536,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "tiny"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/HTSAT-tiny.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "tiny"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/PANN-10.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn10"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 18000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 960000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 360,
10
- "fmin": 50,
11
- "fmax": 8000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 4
22
- }
23
- }
audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1536,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/PANN-14.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/PANN-6.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn6"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
audioldm/clap/open_clip/model_configs/RN101-quickgelu.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": [
7
- 3,
8
- 4,
9
- 23,
10
- 3
11
- ],
12
- "width": 64,
13
- "patch_size": null
14
- },
15
- "text_cfg": {
16
- "context_length": 77,
17
- "vocab_size": 49408,
18
- "width": 512,
19
- "heads": 8,
20
- "layers": 12
21
- }
22
- }
audioldm/clap/open_clip/model_configs/RN101.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": [
6
- 3,
7
- 4,
8
- 23,
9
- 3
10
- ],
11
- "width": 64,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 512,
18
- "heads": 8,
19
- "layers": 12
20
- }
21
- }
audioldm/clap/open_clip/model_configs/RN50-quickgelu.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": [
7
- 3,
8
- 4,
9
- 6,
10
- 3
11
- ],
12
- "width": 64,
13
- "patch_size": null
14
- },
15
- "text_cfg": {
16
- "context_length": 77,
17
- "vocab_size": 49408,
18
- "width": 512,
19
- "heads": 8,
20
- "layers": 12
21
- }
22
- }
audioldm/clap/open_clip/model_configs/RN50.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": [
6
- 3,
7
- 4,
8
- 6,
9
- 3
10
- ],
11
- "width": 64,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 512,
18
- "heads": 8,
19
- "layers": 12
20
- }
21
- }