Mudrock committed on
Commit 1b7174a · 1 Parent(s): f3b8c25

Update app.py

Files changed (1)
  1. app.py +868 -79
app.py CHANGED
@@ -1,88 +1,877 @@
1
- import os
2
  import torch
3
- import librosa
4
- from open_clip import create_model
5
- from training.data import get_audio_features
6
- from training.data import int16_to_float32, float32_to_int16
7
- from transformers import RobertaTokenizer
8
-
9
- tokenize = RobertaTokenizer.from_pretrained('roberta-base')
10
- def tokenizer(text):
11
- result = tokenize(
12
- text,
13
- padding="max_length",
14
- truncation=True,
15
- max_length=77,
16
- return_tensors="pt",
17
  )
18
- return {k: v.squeeze(0) for k, v in result.items()}
19
-
20
- def infer_text():
21
- device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
22
- precision = 'fp32'
23
- amodel = 'HTSAT-tiny' # or 'PANN-14'
24
- tmodel = 'roberta' # the best text encoder in our training
25
- enable_fusion = True # False if you do not want to use the fusion model
26
- fusion_type = 'aff_2d'
27
- pretrained = "/home/la/kechen/Research/KE_CLAP/ckpt/fusion_best.pt" # the checkpoint name, the unfusion model can also be loaded.
28
-
29
- model, model_cfg = create_model(
30
- amodel,
31
- tmodel,
32
- pretrained,
33
- precision=precision,
34
- device=device,
35
- enable_fusion=enable_fusion,
36
- fusion_type=fusion_type
37
  )
38
- # load the text, can be a list (i.e. batch size)
39
- text_data = ["I love the contrastive learning", "I love the pretrain model"]
40
- # tokenize for roberta, if you want to tokenize for another text encoder, please refer to data.py#L43-90
41
- text_data = tokenizer(text_data)
42
-
43
- text_embed = model.get_text_embedding(text_data)
44
- print(text_embed.size())
45
-
46
- def infer_audio():
47
-
48
- device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
49
- precision = 'fp32'
50
- amodel = 'HTSAT-tiny' # or 'PANN-14'
51
- tmodel = 'roberta' # the best text encoder in our training
52
- enable_fusion = True # False if you do not want to use the fusion model
53
- fusion_type = 'aff_2d'
54
- pretrained = "/home/la/kechen/Research/KE_CLAP/ckpt/fusion_best.pt" # the checkpoint name, the unfusion model can also be loaded.
55
-
56
- model, model_cfg = create_model(
57
- amodel,
58
- tmodel,
59
- pretrained,
60
- precision=precision,
61
- device=device,
62
  enable_fusion=enable_fusion,
63
  fusion_type=fusion_type
64
  )
65
 
66
- # load the waveform of the shape (T,), should resample to 48000
67
- audio_waveform, sr = librosa.load('/home/la/kechen/Research/KE_CLAP/ckpt/test_clap_long.wav', sr=48000)
68
- # quantize
69
- audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
70
- audio_waveform = torch.from_numpy(audio_waveform).float()
71
- audio_dict = {}
72
-
73
- # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
74
- audio_dict = get_audio_features(
75
- audio_dict, audio_waveform, 480000,
76
- data_truncating='fusion',
77
- data_filling='repeatpad',
78
- audio_cfg=model_cfg['audio_cfg']
79
- )
80
- # can send a list to the model, to process many audio tracks in one time (i.e. batch size)
81
- audio_embed = model.get_audio_embedding([audio_dict])
82
- print(audio_embed.size())
83
-
84
 
85
 
86
- if __name__ == "__main__":
87
- infer_text()
88
- pip install torch
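
For reference, the tail of the removed quick-start script above was broken: "pip install torch" is a shell command, not Python (it makes the module fail to parse), and infer_audio() was defined but never called. A minimal corrected sketch of that entry point, assuming the infer_text and infer_audio helpers defined above, with the install moved out to the shell, would be:

# install dependencies from the shell beforehand, e.g.
#   pip install torch librosa transformers

if __name__ == "__main__":
    infer_text()
    infer_audio()
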
1
+ from collections import OrderedDict
2
+ from dataclasses import dataclass
3
+ from email.mime import audio
4
+ from typing import Tuple, Union, Callable, Optional
5
+
6
+ import numpy as np
7
  import torch
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
+
11
+ from .timm_model import TimmModel
12
+ import logging
13
+ from .utils import freeze_batch_norm_2d
14
+
15
+ from .pann_model import create_pann_model
16
+ from .htsat import create_htsat_model
17
+ from transformers import BertModel, RobertaModel, BartModel
18
+ from transformers.tokenization_utils_base import BatchEncoding
19
+
20
+
21
+ class MLPLayers(nn.Module):
22
+ def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
23
+ super(MLPLayers, self).__init__()
24
+ self.nonlin = nonlin
25
+ self.dropout = dropout
26
+
27
+ sequence = []
28
+ for u0, u1 in zip(units[:-1], units[1:]):
29
+ sequence.append(nn.Linear(u0, u1))
30
+ sequence.append(self.nonlin)
31
+ sequence.append(nn.Dropout(self.dropout))
32
+ sequence = sequence[:-2]
33
+
34
+ self.sequential = nn.Sequential(*sequence)
35
+
36
+ def forward(self, X):
37
+ X = self.sequential(X)
38
+ return X
39
+
40
+
41
+ class Bottleneck(nn.Module):
42
+ expansion = 4
43
+
44
+ def __init__(self, inplanes, planes, stride=1):
45
+ super().__init__()
46
+
47
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
48
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
49
+ self.bn1 = nn.BatchNorm2d(planes)
50
+
51
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
52
+ self.bn2 = nn.BatchNorm2d(planes)
53
+
54
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
55
+
56
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
57
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
58
+
59
+ self.relu = nn.ReLU(inplace=True)
60
+ self.downsample = None
61
+ self.stride = stride
62
+
63
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
64
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
65
+ self.downsample = nn.Sequential(
66
+ OrderedDict(
67
+ [
68
+ ("-1", nn.AvgPool2d(stride)),
69
+ (
70
+ "0",
71
+ nn.Conv2d(
72
+ inplanes,
73
+ planes * self.expansion,
74
+ 1,
75
+ stride=1,
76
+ bias=False,
77
+ ),
78
+ ),
79
+ ("1", nn.BatchNorm2d(planes * self.expansion)),
80
+ ]
81
+ )
82
+ )
83
+
84
+ def forward(self, x: torch.Tensor):
85
+ identity = x
86
+
87
+ out = self.relu(self.bn1(self.conv1(x)))
88
+ out = self.relu(self.bn2(self.conv2(out)))
89
+ out = self.avgpool(out)
90
+ out = self.bn3(self.conv3(out))
91
+
92
+ if self.downsample is not None:
93
+ identity = self.downsample(x)
94
+
95
+ out += identity
96
+ out = self.relu(out)
97
+ return out
98
+
99
+
100
+ class AttentionPool2d(nn.Module):
101
+ def __init__(
102
+ self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
103
+ ):
104
+ super().__init__()
105
+ self.positional_embedding = nn.Parameter(
106
+ torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
107
+ )
108
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
109
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
110
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
111
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
112
+ self.num_heads = num_heads
113
+
114
+ def forward(self, x):
115
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
116
+ 2, 0, 1
117
+ ) # NCHW -> (HW)NC
118
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
119
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
120
+ x, _ = F.multi_head_attention_forward(
121
+ query=x,
122
+ key=x,
123
+ value=x,
124
+ embed_dim_to_check=x.shape[-1],
125
+ num_heads=self.num_heads,
126
+ q_proj_weight=self.q_proj.weight,
127
+ k_proj_weight=self.k_proj.weight,
128
+ v_proj_weight=self.v_proj.weight,
129
+ in_proj_weight=None,
130
+ in_proj_bias=torch.cat(
131
+ [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
132
+ ),
133
+ bias_k=None,
134
+ bias_v=None,
135
+ add_zero_attn=False,
136
+ dropout_p=0,
137
+ out_proj_weight=self.c_proj.weight,
138
+ out_proj_bias=self.c_proj.bias,
139
+ use_separate_proj_weight=True,
140
+ training=self.training,
141
+ need_weights=False,
142
+ )
143
+
144
+ return x[0]
145
+
146
+
147
+ class ModifiedResNet(nn.Module):
148
+ """
149
+ A ResNet class that is similar to torchvision's but contains the following changes:
150
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
151
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
152
+ - The final pooling layer is a QKV attention instead of an average pool
153
+ """
154
+
155
+ def __init__(self, layers, output_dim, heads, image_size=224, width=64):
156
+ super().__init__()
157
+ self.output_dim = output_dim
158
+ self.image_size = image_size
159
+
160
+ # the 3-layer stem
161
+ self.conv1 = nn.Conv2d(
162
+ 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
163
+ )
164
+ self.bn1 = nn.BatchNorm2d(width // 2)
165
+ self.conv2 = nn.Conv2d(
166
+ width // 2, width // 2, kernel_size=3, padding=1, bias=False
167
+ )
168
+ self.bn2 = nn.BatchNorm2d(width // 2)
169
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
170
+ self.bn3 = nn.BatchNorm2d(width)
171
+ self.avgpool = nn.AvgPool2d(2)
172
+ self.relu = nn.ReLU(inplace=True)
173
+
174
+ # residual layers
175
+ self._inplanes = width # this is a *mutable* variable used during construction
176
+ self.layer1 = self._make_layer(width, layers[0])
177
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
178
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
179
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
180
+
181
+ embed_dim = width * 32 # the ResNet feature dimension
182
+ self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
183
+
184
+ self.init_parameters()
185
+
186
+ def _make_layer(self, planes, blocks, stride=1):
187
+ layers = [Bottleneck(self._inplanes, planes, stride)]
188
+
189
+ self._inplanes = planes * Bottleneck.expansion
190
+ for _ in range(1, blocks):
191
+ layers.append(Bottleneck(self._inplanes, planes))
192
+
193
+ return nn.Sequential(*layers)
194
+
195
+ def init_parameters(self):
196
+ if self.attnpool is not None:
197
+ std = self.attnpool.c_proj.in_features**-0.5
198
+ nn.init.normal_(self.attnpool.q_proj.weight, std=std)
199
+ nn.init.normal_(self.attnpool.k_proj.weight, std=std)
200
+ nn.init.normal_(self.attnpool.v_proj.weight, std=std)
201
+ nn.init.normal_(self.attnpool.c_proj.weight, std=std)
202
+
203
+ for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
204
+ for name, param in resnet_block.named_parameters():
205
+ if name.endswith("bn3.weight"):
206
+ nn.init.zeros_(param)
207
+
208
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
209
+ assert (
210
+ unlocked_groups == 0
211
+ ), "partial locking not currently supported for this model"
212
+ for param in self.parameters():
213
+ param.requires_grad = False
214
+ if freeze_bn_stats:
215
+ freeze_batch_norm_2d(self)
216
+
217
+ def stem(self, x):
218
+ for conv, bn in [
219
+ (self.conv1, self.bn1),
220
+ (self.conv2, self.bn2),
221
+ (self.conv3, self.bn3),
222
+ ]:
223
+ x = self.relu(bn(conv(x)))
224
+ x = self.avgpool(x)
225
+ return x
226
+
227
+ def forward(self, x):
228
+ x = self.stem(x)
229
+ x = self.layer1(x)
230
+ x = self.layer2(x)
231
+ x = self.layer3(x)
232
+ x = self.layer4(x)
233
+ x = self.attnpool(x)
234
+
235
+ return x
236
+
237
+
238
+ class LayerNorm(nn.LayerNorm):
239
+ """Subclass torch's LayerNorm to handle fp16."""
240
+
241
+ def forward(self, x: torch.Tensor):
242
+ orig_type = x.dtype
243
+ x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
244
+ return x.to(orig_type)
245
+
246
+
247
+ class QuickGELU(nn.Module):
248
+ # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
249
+ def forward(self, x: torch.Tensor):
250
+ return x * torch.sigmoid(1.702 * x)
251
+
252
+
253
+ class ResidualAttentionBlock(nn.Module):
254
+ def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
255
+ super().__init__()
256
+
257
+ self.attn = nn.MultiheadAttention(d_model, n_head)
258
+ self.ln_1 = LayerNorm(d_model)
259
+ self.mlp = nn.Sequential(
260
+ OrderedDict(
261
+ [
262
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
263
+ ("gelu", act_layer()),
264
+ ("c_proj", nn.Linear(d_model * 4, d_model)),
265
+ ]
266
+ )
267
+ )
268
+ self.ln_2 = LayerNorm(d_model)
269
+
270
+ def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
271
+ return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
272
+
273
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
274
+ x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
275
+ x = x + self.mlp(self.ln_2(x))
276
+ return x
277
+
278
+
279
+ class Transformer(nn.Module):
280
+ def __init__(
281
+ self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
282
+ ):
283
+ super().__init__()
284
+ self.width = width
285
+ self.layers = layers
286
+ self.resblocks = nn.ModuleList(
287
+ [
288
+ ResidualAttentionBlock(width, heads, act_layer=act_layer)
289
+ for _ in range(layers)
290
+ ]
291
+ )
292
+
293
+ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
294
+ for r in self.resblocks:
295
+ x = r(x, attn_mask=attn_mask)
296
+ return x
297
+
298
+
299
+ class VisualTransformer(nn.Module):
300
+ def __init__(
301
+ self,
302
+ image_size: int,
303
+ patch_size: int,
304
+ width: int,
305
+ layers: int,
306
+ heads: int,
307
+ output_dim: int,
308
+ act_layer: Callable = nn.GELU,
309
+ ):
310
+ super().__init__()
311
+ self.image_size = image_size
312
+ self.output_dim = output_dim
313
+ self.conv1 = nn.Conv2d(
314
+ in_channels=3,
315
+ out_channels=width,
316
+ kernel_size=patch_size,
317
+ stride=patch_size,
318
+ bias=False,
319
+ )
320
+
321
+ scale = width**-0.5
322
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
323
+ self.positional_embedding = nn.Parameter(
324
+ scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
325
+ )
326
+ self.ln_pre = LayerNorm(width)
327
+
328
+ self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
329
+
330
+ self.ln_post = LayerNorm(width)
331
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
332
+
333
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
334
+ assert (
335
+ unlocked_groups == 0
336
+ ), "partial locking not currently supported for this model"
337
+ for param in self.parameters():
338
+ param.requires_grad = False
339
+
340
+ def forward(self, x: torch.Tensor):
341
+ x = self.conv1(x) # shape = [*, width, grid, grid]
342
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
343
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
344
+ x = torch.cat(
345
+ [
346
+ self.class_embedding.to(x.dtype)
347
+ + torch.zeros(
348
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
349
+ ),
350
+ x,
351
+ ],
352
+ dim=1,
353
+ ) # shape = [*, grid ** 2 + 1, width]
354
+ x = x + self.positional_embedding.to(x.dtype)
355
+ x = self.ln_pre(x)
356
+
357
+ x = x.permute(1, 0, 2) # NLD -> LND
358
+ x = self.text_branch(x)
359
+ x = x.permute(1, 0, 2) # LND -> NLD
360
+
361
+ x = self.ln_post(x[:, 0, :])
362
+
363
+ if self.proj is not None:
364
+ x = x @ self.proj
365
+
366
+ return x
367
+
368
+
369
+ @dataclass
370
+ class CLAPVisionCfg:
371
+ layers: Union[Tuple[int, int, int, int], int] = 12
372
+ width: int = 768
373
+ patch_size: int = 16
374
+ image_size: Union[Tuple[int, int], int] = 224
375
+ timm_model_name: str = (
376
+ None # a valid model name overrides layers, width, patch_size
377
  )
378
+ timm_model_pretrained: bool = (
379
+ False # use (imagenet) pretrained weights for named model
380
+ )
381
+ timm_pool: str = (
382
+ "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
383
+ )
384
+ timm_proj: str = (
385
+ "linear" # linear projection for timm model output ('linear', 'mlp', '')
386
  )
387
+
388
+
389
+ # Audio Config Class
390
+ @dataclass
391
+ class CLAPAudioCfp:
392
+ model_type: str = "PANN"
393
+ model_name: str = "Cnn14"
394
+ sample_rate: int = 48000
395
+ # Param
396
+ audio_length: int = 1024
397
+ window_size: int = 1024
398
+ hop_size: int = 1024
399
+ fmin: int = 50
400
+ fmax: int = 14000
401
+ class_num: int = 527
402
+ mel_bins: int = 64
403
+ clip_samples: int = 480000
404
+
405
+
406
+ @dataclass
407
+ class CLAPTextCfg:
408
+ context_length: int
409
+ vocab_size: int
410
+ width: int
411
+ heads: int
412
+ layers: int
413
+ model_type: str
414
+
415
+
416
+ class CLAP(nn.Module):
417
+ def __init__(
418
+ self,
419
+ embed_dim: int,
420
+ audio_cfg: CLAPAudioCfp,
421
+ text_cfg: CLAPTextCfg,
422
+ quick_gelu: bool = False,
423
+ enable_fusion: bool = False,
424
+ fusion_type: str = 'None',
425
+ joint_embed_shape: int = 512,
426
+ mlp_act: str = 'relu',
427
+ ):
428
+ super().__init__()
429
+ if isinstance(audio_cfg, dict):
430
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
431
+ if isinstance(text_cfg, dict):
432
+ text_cfg = CLAPTextCfg(**text_cfg)
433
+
434
+ self.audio_cfg = audio_cfg
435
+ self.text_cfg = text_cfg
436
+ self.enable_fusion = enable_fusion
437
+ self.fusion_type = fusion_type
438
+ self.joint_embed_shape = joint_embed_shape
439
+ self.mlp_act = mlp_act
440
+
441
+
442
+ self.context_length = text_cfg.context_length
443
+
444
+ # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
445
+ # memory efficient in recent PyTorch releases (>= 1.10).
446
+ # NOTE: timm models always use native GELU regardless of quick_gelu flag.
447
+ act_layer = QuickGELU if quick_gelu else nn.GELU
448
+
449
+ if mlp_act == 'relu':
450
+ mlp_act_layer = nn.ReLU()
451
+ elif mlp_act == 'gelu':
452
+ mlp_act_layer = nn.GELU()
453
+ else:
454
+ raise NotImplementedError
455
+
456
+ # audio branch
457
+ # audio branch parameters
458
+ if audio_cfg.model_type == "PANN":
459
+ self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
460
+ elif audio_cfg.model_type == "HTSAT":
461
+ self.audio_branch = create_htsat_model(audio_cfg, enable_fusion, fusion_type)
462
+ else:
463
+ logging.error(f"Model config for {audio_cfg.model_type} not found")
464
+ raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
465
+
466
+ # text branch
467
+ # text branch parameters
468
+ if text_cfg.model_type == "transformer":
469
+ self.text_branch = Transformer(
470
+ width=text_cfg.width,
471
+ layers=text_cfg.layers,
472
+ heads=text_cfg.heads,
473
+ act_layer=act_layer,
474
+ )
475
+ self.vocab_size = text_cfg.vocab_size
476
+ self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
477
+ self.positional_embedding = nn.Parameter(
478
+ torch.empty(self.context_length, text_cfg.width)
479
+ )
480
+ self.ln_final = LayerNorm(text_cfg.width)
481
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
482
+ self.joint_embed_shape,
483
+ self.joint_embed_shape], dropout=0.1)
484
+ self.text_projection = nn.Sequential(
485
+ nn.Linear(text_cfg.width, self.joint_embed_shape),
486
+ mlp_act_layer,
487
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
488
+ )
489
+ elif text_cfg.model_type == "bert":
490
+ self.text_branch = BertModel.from_pretrained("bert-base-uncased")
491
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
492
+ self.joint_embed_shape,
493
+ self.joint_embed_shape], dropout=0.1)
494
+ self.text_projection = nn.Sequential(
495
+ nn.Linear(768, self.joint_embed_shape),
496
+ mlp_act_layer,
497
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
498
+ )
499
+ elif text_cfg.model_type == "roberta":
500
+ self.text_branch = RobertaModel.from_pretrained('roberta-base')
501
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
502
+ self.joint_embed_shape,
503
+ self.joint_embed_shape], dropout=0.1)
504
+ self.text_projection = nn.Sequential(
505
+ nn.Linear(768, self.joint_embed_shape),
506
+ mlp_act_layer,
507
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
508
+ )
509
+ elif text_cfg.model_type == "bart":
510
+ self.text_branch = BartModel.from_pretrained('facebook/bart-base')
511
+ self.text_transform = MLPLayers(units=[self.joint_embed_shape,
512
+ self.joint_embed_shape,
513
+ self.joint_embed_shape], dropout=0.1)
514
+ self.text_projection = nn.Sequential(
515
+ nn.Linear(768, self.joint_embed_shape),
516
+ mlp_act_layer,
517
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
518
+ )
519
+ else:
520
+ logging.error(f"Model config for {text_cfg.model_type} not found")
521
+ raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
522
+ self.text_branch_type = text_cfg.model_type
523
+ # text branch parameters
524
+
525
+ # audio branch parameters
526
+ self.audio_transform = MLPLayers(units=[self.joint_embed_shape,
527
+ self.joint_embed_shape,
528
+ self.joint_embed_shape], dropout=0.1)
529
+
530
+ # below here is text branch parameters
531
+
532
+ # ============================================================================================================
533
+ self.audio_projection = nn.Sequential(
534
+ nn.Linear(embed_dim, self.joint_embed_shape),
535
+ mlp_act_layer,
536
+ nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
537
+ )
538
+
539
+ self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
540
+ self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
541
+ self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
542
+
543
+ self.init_text_branch_parameters()
544
+
545
+ def init_text_branch_parameters(self):
546
+ if self.text_branch_type == "transformer":
547
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
548
+ nn.init.normal_(self.positional_embedding, std=0.01)
549
+ proj_std = (self.text_branch.width**-0.5) * (
550
+ (2 * self.text_branch.layers) ** -0.5
551
+ )
552
+ attn_std = self.text_branch.width**-0.5
553
+ fc_std = (2 * self.text_branch.width) ** -0.5
554
+ for block in self.text_branch.resblocks:
555
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
556
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
557
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
558
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
559
+ if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
560
+ width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
561
+ elif self.text_branch_type == "bart":
562
+ width = self.text_branch.shared.weight.shape[-1]
563
+ else:
564
+ width = self.text_branch.width
565
+ nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
566
+ nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
567
+
568
+ # deprecated
569
+ # if hasattr(self.visual, 'init_parameters'):
570
+ # self.visual.init_parameters()
571
+
572
+ # if self.text_projection is not None:
573
+ # nn.init.normal_(self.text_projection, std=width**-0.5)
574
+
575
+ def build_attention_mask(self):
576
+ # lazily create causal attention mask, with full attention between the vision tokens
577
+ # pytorch uses additive attention mask; fill with -inf
578
+ mask = torch.empty(self.context_length, self.context_length)
579
+ mask.fill_(float("-inf"))
580
+ mask.triu_(1) # zero out the lower diagonal
581
+ return mask
582
+
583
+ def encode_audio(self, audio, device):
584
+ return self.audio_branch(audio, mixup_lambda=None, device=device) # mix lambda needs to add
585
+
586
+ # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
587
+ # tmp = {}
588
+ # for k in x[0].keys():
589
+ # tmp[k] = []
590
+ # for i in range(len(x)):
591
+ # tmp[k].append(x[i][k][:77])
592
+ # for k in x[0].keys():
593
+ # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
594
+ # return tmp
595
+
596
+ def encode_text(self, text, device):
597
+ if self.text_branch_type == "transformer":
598
+ text = text.to(device=device, non_blocking=True)
599
+ x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
600
+
601
+ x = x + self.positional_embedding
602
+ x = x.permute(1, 0, 2) # NLD -> LND
603
+ x = self.text_branch(x, attn_mask=self.attn_mask)
604
+ x = x.permute(1, 0, 2) # LND -> NLD
605
+ x = self.ln_final(x)
606
+
607
+ # x.shape = [batch_size, n_ctx, transformer.width]
608
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
609
+ x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
610
+ elif self.text_branch_type == "bert":
611
+ # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
612
+ # text = BatchEncoding(text)
613
+ x = self.text_branch(
614
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
615
+ attention_mask=text["attention_mask"].to(
616
+ device=device, non_blocking=True
617
+ ),
618
+ token_type_ids=text["token_type_ids"].to(
619
+ device=device, non_blocking=True
620
+ ),
621
+ )["pooler_output"]
622
+ x = self.text_projection(x)
623
+ elif self.text_branch_type == "roberta":
624
+ x = self.text_branch(
625
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
626
+ attention_mask=text["attention_mask"].to(
627
+ device=device, non_blocking=True
628
+ ),
629
+ )["pooler_output"]
630
+ x = self.text_projection(x)
631
+ elif self.text_branch_type == "bart":
632
+ x = torch.mean(self.text_branch(
633
+ input_ids=text["input_ids"].to(device=device, non_blocking=True),
634
+ attention_mask=text["attention_mask"].to(
635
+ device=device, non_blocking=True
636
+ ),
637
+ )["encoder_last_hidden_state"],axis=1)
638
+ x = self.text_projection(x)
639
+ else:
640
+ logging.error(f"Model type {self.text_branch_type} not found")
641
+ raise RuntimeError(f"Model type {self.text_branch_type} not found.")
642
+ return x
643
+
644
+ def forward(self, audio, text, device=None):
645
+ """Forward audio and text into the CLAP
646
+ Parameters
647
+ ----------
648
+ audio: torch.Tensor (batch_size, audio_length)
649
+ the time-domain audio input / the batch of mel_spec and longer list.
650
+ text: torch.Tensor () // need to add
651
+ the text token input
652
+ """
653
+ if device is None:
654
+ if audio is not None:
655
+ device = audio.device
656
+ elif text is not None:
657
+ device = text.device
658
+ if audio is None and text is None:
659
+ # a hack to get the logit scale
660
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
661
+ elif audio is None:
662
+ return self.encode_text(text, device=device)
663
+ elif text is None:
664
+ return self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
665
+ audio_features = self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
666
+ audio_features = F.normalize(audio_features, dim=-1)
667
+
668
+ text_features = self.encode_text(
669
+ text, device=device
670
+ )
671
+ # print("text_features", text_features)
672
+ # print("text_features.shape", text_features.shape)
673
+ # print("text_features.type", type(text_features))
674
+ text_features = F.normalize(text_features, dim=-1)
675
+
676
+ audio_features_mlp = self.audio_transform(audio_features)
677
+ text_features_mlp = self.text_transform(text_features)
678
+ # Four outputs: audio features (basic & MLP), text features (basic & MLP)
679
+ return (
680
+ audio_features,
681
+ text_features,
682
+ audio_features_mlp,
683
+ text_features_mlp,
684
+ self.logit_scale_a.exp(),
685
+ self.logit_scale_t.exp(),
686
+ )
687
+
688
+ def get_logit_scale(self):
689
+ return self.logit_scale_a.exp(), self.logit_scale_t.exp()
690
+
691
+ def get_text_embedding(self, data):
692
+ """Get the text embedding from the model
693
+ Parameters
694
+ ----------
695
+ data: torch.Tensor
696
+ a tensor of text embedding
697
+ Returns
698
+ ----------
699
+ text_embed: torch.Tensor
700
+ a tensor of text_embeds (N, D)
701
+ """
702
+ device = next(self.parameters()).device
703
+ for k in data:
704
+ data[k] = data[k].to(device)
705
+ text_embeds = self.encode_text(data, device=device)
706
+ text_embeds = F.normalize(text_embeds, dim=-1)
707
+
708
+ return text_embeds
709
+
710
+ def get_audio_embedding(self, data):
711
+ """Get the audio embedding from the model
712
+ Parameters
713
+ ----------
714
+ data: a list of dict
715
+ the audio input dict list from 'get_audio_feature' method
716
+ Returns
717
+ ----------
718
+ audio_embed: torch.Tensor
719
+ a tensor of audio_embeds (N, D)
720
+ """
721
+ device = next(self.parameters()).device
722
+ input_dict = {}
723
+ keys = data[0].keys()
724
+ for k in keys:
725
+ input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(device)
726
+
727
+ audio_embeds = self.audio_projection(self.encode_audio(input_dict, device=device)["embedding"])
728
+ audio_embeds = F.normalize(audio_embeds, dim=-1)
729
+
730
+ return audio_embeds
731
+
732
+
733
+
734
+ def audio_infer(self, audio, hopsize=None, device=None):
735
+ """Forward one audio and produce the audio embedding
736
+ Parameters
737
+ ----------
738
+ audio: (audio_length)
739
+ the time-domain audio input, notice that it must be only one input
740
+ hopsize: int
741
+ the overlap hopsize as the sliding window
742
+ Returns
743
+ ----------
744
+ output_dict: {
745
+ key: [n, (embedding_shape)] if "HTS-AT"
746
+ or
747
+ key: [(embedding_shape)] if "PANN"
748
+ }
749
+ the list of key values of the audio branch
750
+ """
751
+
752
+ assert not self.training, "the inference mode must be run at eval stage"
753
+ output_dict = {}
754
+ # PANN
755
+ if self.audio_cfg.model_type == "PANN":
756
+ audio_input = audio.unsqueeze(dim=0)
757
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
758
+ elif self.audio_cfg.model_type == "HTSAT":
759
+ # repeat
760
+ audio_len = len(audio)
761
+ k = self.audio_cfg.clip_samples // audio_len
762
+ if k > 1:
763
+ audio = audio.repeat(k)
764
+ audio_len = len(audio)
765
+
766
+ if hopsize is None:
767
+ hopsize = min(hopsize, audio_len)
768
+
769
+ if audio_len > self.audio_cfg.clip_samples:
770
+ audio_input = [
771
+ audio[pos : pos + self.audio_cfg.clip_samples].clone()
772
+ for pos in range(
773
+ 0, audio_len - self.audio_cfg.clip_samples, hopsize
774
+ )
775
+ ]
776
+ audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
777
+ audio_input = torch.stack(audio_input)
778
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key]
779
+ else:
780
+ audio_input = audio.unsqueeze(dim=0)
781
+ output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
782
+
783
+ return output_dict
784
+
785
+
786
+ def convert_weights_to_fp16(model: nn.Module):
787
+ """Convert applicable model parameters to fp16"""
788
+
789
+ def _convert_weights_to_fp16(l):
790
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
791
+ l.weight.data = l.weight.data.half()
792
+ if l.bias is not None:
793
+ l.bias.data = l.bias.data.half()
794
+
795
+ if isinstance(l, nn.MultiheadAttention):
796
+ for attr in [
797
+ *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
798
+ "in_proj_bias",
799
+ "bias_k",
800
+ "bias_v",
801
+ ]:
802
+ tensor = getattr(l, attr)
803
+ if tensor is not None:
804
+ tensor.data = tensor.data.half()
805
+
806
+ for name in ["text_projection", "proj"]:
807
+ if hasattr(l, name):
808
+ attr = getattr(l, name)
809
+ if attr is not None:
810
+ attr.data = attr.data.half()
811
+
812
+ model.apply(_convert_weights_to_fp16)
813
+
814
+
815
+ # Ignore the state dict of the vision part
816
+ def build_model_from_openai_state_dict(state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = 'None'):
817
+
818
+ embed_dim = model_cfg["embed_dim"]
819
+ audio_cfg = model_cfg["audio_cfg"]
820
+ text_cfg = model_cfg["text_cfg"]
821
+ context_length = state_dict["positional_embedding"].shape[0]
822
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
823
+ transformer_width = state_dict["ln_final.weight"].shape[0]
824
+ transformer_heads = transformer_width // 64
825
+ transformer_layers = len(
826
+ set(
827
+ k.split(".")[2]
828
+ for k in state_dict
829
+ if k.startswith(f"transformer.resblocks")
830
+ )
831
+ )
832
+
833
+ audio_cfg = CLAPAudioCfp(**audio_cfg)
834
+ text_cfg = CLAPTextCfg(**text_cfg)
835
+
836
+ model = CLAP(
837
+ embed_dim,
838
+ audio_cfg=audio_cfg,
839
+ text_cfg=text_cfg,
840
+ quick_gelu=True, # OpenAI models were trained with QuickGELU
841
  enable_fusion=enable_fusion,
842
  fusion_type=fusion_type
843
  )
844
+ state_dict["logit_scale_a"] = state_dict["logit_scale"]
845
+ state_dict["logit_scale_t"] = state_dict["logit_scale"]
846
+ pop_keys = list(state_dict.keys())[::]
847
+ # pop the visual branch saved weights
848
+ for key in pop_keys:
849
+ if key.startswith("visual."):
850
+ state_dict.pop(key, None)
851
 
852
+ for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
853
+ state_dict.pop(key, None)
854
 
855
+ # not use fp16
856
+ # convert_weights_to_fp16(model)
857
+ model.load_state_dict(state_dict, strict=False)
858
+ return model.eval()
859
 
860
+
861
+ def trace_model(model, batch_size=256, device=torch.device("cpu")):
862
+ model.eval()
863
+ audio_length = model.audio_cfg.audio_length
864
+ example_audio = torch.ones((batch_size, audio_length), device=device)
865
+ example_text = torch.zeros(
866
+ (batch_size, model.context_length), dtype=torch.int, device=device
867
+ )
868
+ model = torch.jit.trace_module(
869
+ model,
870
+ inputs=dict(
871
+ forward=(example_audio, example_text),
872
+ encode_text=(example_text,),
873
+ encode_image=(example_audio,),
874
+ ),
875
+ )
876
+ model.audio_cfg.audio_length = audio_length # Question: what does this do?
877
+ return model
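
The new app.py thus replaces the inference demo with the CLAP model definition itself: an audio branch (PANN or HTSAT), a text branch (CLIP-style transformer, BERT, RoBERTa, or BART), MLP and linear projection heads into the joint embedding space, and a loader for OpenAI-style checkpoints. A minimal usage sketch, assuming the sibling modules referenced by the relative imports (.htsat, .pann_model, .timm_model, .utils) are available and that model_cfg is one of the usual CLAP config dicts with embed_dim, audio_cfg and text_cfg entries (none of which are defined in this file), might look like:

import torch
from transformers import RobertaTokenizer

# model_cfg is a hypothetical CLAP config dict (embed_dim / audio_cfg / text_cfg),
# e.g. an HTSAT-tiny + roberta configuration; it is not defined in this file.
model = CLAP(
    embed_dim=model_cfg["embed_dim"],
    audio_cfg=model_cfg["audio_cfg"],
    text_cfg=model_cfg["text_cfg"],
    enable_fusion=False,
).eval()

# Text side: get_text_embedding expects the tokenizer output as a dict of tensors.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
text = tokenizer(
    ["a dog barking", "rain on a window"],
    padding="max_length", truncation=True, max_length=77, return_tensors="pt",
)
with torch.no_grad():
    text_embed = model.get_text_embedding(dict(text))  # (N, joint_embed_shape)

# Audio side: get_audio_embedding expects a list of per-clip dicts as produced by
# get_audio_features in the training code (waveform, longer flag, mel_fusion, ...):
#     audio_embed = model.get_audio_embedding([audio_dict_1, audio_dict_2])

Note that trace_model feeds a raw token tensor to encode_text, so tracing as written only fits the plain "transformer" text branch; the bert/roberta/bart branches expect a dict of input_ids and attention_mask instead.
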