Mudrock committed on
Commit 76d4671 · 1 Parent(s): 691ed0a

Update app.py

Files changed (1)
  1. app.py +14 -872
app.py CHANGED
@@ -1,878 +1,20 @@
1
- from collections import OrderedDict
2
- from dataclasses import dataclass
3
- from email.mime import audio
4
- from typing import Tuple, Union, Callable, Optional
5
 
6
- import numpy as np
7
- # requires torch: install with "pip install torch" beforehand ("!pip install" is notebook syntax, not valid in a .py module)
8
- import torch
9
- import torch.nn.functional as F
10
- from torch import nn
11
 
12
- from .timm_model import TimmModel
13
- import logging
14
- from .utils import freeze_batch_norm_2d
15
 
16
- from .pann_model import create_pann_model
17
- from .htsat import create_htsat_model
18
- from transformers import BertModel, RobertaModel, BartModel
19
- from transformers.tokenization_utils_base import BatchEncoding
20
 
 
 
 
 
 
21
 
22
- class MLPLayers(nn.Module):
23
- def __init__(self, units=[512, 512, 512], nonlin=nn.ReLU(), dropout=0.1):
24
- super(MLPLayers, self).__init__()
25
- self.nonlin = nonlin
26
- self.dropout = dropout
27
 
28
- sequence = []
29
- for u0, u1 in zip(units[:-1], units[1:]):
30
- sequence.append(nn.Linear(u0, u1))
31
- sequence.append(self.nonlin)
32
- sequence.append(nn.Dropout(self.dropout))
33
- sequence = sequence[:-2]
34
-
35
- self.sequential = nn.Sequential(*sequence)
36
-
37
- def forward(self, X):
38
- X = self.sequential(X)
39
- return X
40
-
41
-
42
- class Bottleneck(nn.Module):
43
- expansion = 4
44
-
45
- def __init__(self, inplanes, planes, stride=1):
46
- super().__init__()
47
-
48
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
49
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
50
- self.bn1 = nn.BatchNorm2d(planes)
51
-
52
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
53
- self.bn2 = nn.BatchNorm2d(planes)
54
-
55
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
56
-
57
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
58
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
59
-
60
- self.relu = nn.ReLU(inplace=True)
61
- self.downsample = None
62
- self.stride = stride
63
-
64
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
65
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
66
- self.downsample = nn.Sequential(
67
- OrderedDict(
68
- [
69
- ("-1", nn.AvgPool2d(stride)),
70
- (
71
- "0",
72
- nn.Conv2d(
73
- inplanes,
74
- planes * self.expansion,
75
- 1,
76
- stride=1,
77
- bias=False,
78
- ),
79
- ),
80
- ("1", nn.BatchNorm2d(planes * self.expansion)),
81
- ]
82
- )
83
- )
84
-
85
- def forward(self, x: torch.Tensor):
86
- identity = x
87
-
88
- out = self.relu(self.bn1(self.conv1(x)))
89
- out = self.relu(self.bn2(self.conv2(out)))
90
- out = self.avgpool(out)
91
- out = self.bn3(self.conv3(out))
92
-
93
- if self.downsample is not None:
94
- identity = self.downsample(x)
95
-
96
- out += identity
97
- out = self.relu(out)
98
- return out
99
-
100
-
101
- class AttentionPool2d(nn.Module):
102
- def __init__(
103
- self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
104
- ):
105
- super().__init__()
106
- self.positional_embedding = nn.Parameter(
107
- torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
108
- )
109
- self.k_proj = nn.Linear(embed_dim, embed_dim)
110
- self.q_proj = nn.Linear(embed_dim, embed_dim)
111
- self.v_proj = nn.Linear(embed_dim, embed_dim)
112
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
113
- self.num_heads = num_heads
114
-
115
- def forward(self, x):
116
- x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(
117
- 2, 0, 1
118
- ) # NCHW -> (HW)NC
119
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
120
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
121
- x, _ = F.multi_head_attention_forward(
122
- query=x,
123
- key=x,
124
- value=x,
125
- embed_dim_to_check=x.shape[-1],
126
- num_heads=self.num_heads,
127
- q_proj_weight=self.q_proj.weight,
128
- k_proj_weight=self.k_proj.weight,
129
- v_proj_weight=self.v_proj.weight,
130
- in_proj_weight=None,
131
- in_proj_bias=torch.cat(
132
- [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
133
- ),
134
- bias_k=None,
135
- bias_v=None,
136
- add_zero_attn=False,
137
- dropout_p=0,
138
- out_proj_weight=self.c_proj.weight,
139
- out_proj_bias=self.c_proj.bias,
140
- use_separate_proj_weight=True,
141
- training=self.training,
142
- need_weights=False,
143
- )
144
-
145
- return x[0]
146
-
147
-
148
- class ModifiedResNet(nn.Module):
149
- """
150
- A ResNet class that is similar to torchvision's but contains the following changes:
151
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
152
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
153
- - The final pooling layer is a QKV attention instead of an average pool
154
- """
155
-
156
- def __init__(self, layers, output_dim, heads, image_size=224, width=64):
157
- super().__init__()
158
- self.output_dim = output_dim
159
- self.image_size = image_size
160
-
161
- # the 3-layer stem
162
- self.conv1 = nn.Conv2d(
163
- 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
164
- )
165
- self.bn1 = nn.BatchNorm2d(width // 2)
166
- self.conv2 = nn.Conv2d(
167
- width // 2, width // 2, kernel_size=3, padding=1, bias=False
168
- )
169
- self.bn2 = nn.BatchNorm2d(width // 2)
170
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
171
- self.bn3 = nn.BatchNorm2d(width)
172
- self.avgpool = nn.AvgPool2d(2)
173
- self.relu = nn.ReLU(inplace=True)
174
-
175
- # residual layers
176
- self._inplanes = width # this is a *mutable* variable used during construction
177
- self.layer1 = self._make_layer(width, layers[0])
178
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
179
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
180
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
181
-
182
- embed_dim = width * 32 # the ResNet feature dimension
183
- self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
184
-
185
- self.init_parameters()
186
-
187
- def _make_layer(self, planes, blocks, stride=1):
188
- layers = [Bottleneck(self._inplanes, planes, stride)]
189
-
190
- self._inplanes = planes * Bottleneck.expansion
191
- for _ in range(1, blocks):
192
- layers.append(Bottleneck(self._inplanes, planes))
193
-
194
- return nn.Sequential(*layers)
195
-
196
- def init_parameters(self):
197
- if self.attnpool is not None:
198
- std = self.attnpool.c_proj.in_features**-0.5
199
- nn.init.normal_(self.attnpool.q_proj.weight, std=std)
200
- nn.init.normal_(self.attnpool.k_proj.weight, std=std)
201
- nn.init.normal_(self.attnpool.v_proj.weight, std=std)
202
- nn.init.normal_(self.attnpool.c_proj.weight, std=std)
203
-
204
- for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
205
- for name, param in resnet_block.named_parameters():
206
- if name.endswith("bn3.weight"):
207
- nn.init.zeros_(param)
208
-
209
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
210
- assert (
211
- unlocked_groups == 0
212
- ), "partial locking not currently supported for this model"
213
- for param in self.parameters():
214
- param.requires_grad = False
215
- if freeze_bn_stats:
216
- freeze_batch_norm_2d(self)
217
-
218
- def stem(self, x):
219
- for conv, bn in [
220
- (self.conv1, self.bn1),
221
- (self.conv2, self.bn2),
222
- (self.conv3, self.bn3),
223
- ]:
224
- x = self.relu(bn(conv(x)))
225
- x = self.avgpool(x)
226
- return x
227
-
228
- def forward(self, x):
229
- x = self.stem(x)
230
- x = self.layer1(x)
231
- x = self.layer2(x)
232
- x = self.layer3(x)
233
- x = self.layer4(x)
234
- x = self.attnpool(x)
235
-
236
- return x
237
-
238
-
239
- class LayerNorm(nn.LayerNorm):
240
- """Subclass torch's LayerNorm to handle fp16."""
241
-
242
- def forward(self, x: torch.Tensor):
243
- orig_type = x.dtype
244
- x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
245
- return x.to(orig_type)
246
-
247
-
248
- class QuickGELU(nn.Module):
249
- # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
250
- def forward(self, x: torch.Tensor):
251
- return x * torch.sigmoid(1.702 * x)
252
-
253
-
254
- class ResidualAttentionBlock(nn.Module):
255
- def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
256
- super().__init__()
257
-
258
- self.attn = nn.MultiheadAttention(d_model, n_head)
259
- self.ln_1 = LayerNorm(d_model)
260
- self.mlp = nn.Sequential(
261
- OrderedDict(
262
- [
263
- ("c_fc", nn.Linear(d_model, d_model * 4)),
264
- ("gelu", act_layer()),
265
- ("c_proj", nn.Linear(d_model * 4, d_model)),
266
- ]
267
- )
268
- )
269
- self.ln_2 = LayerNorm(d_model)
270
-
271
- def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
272
- return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
273
-
274
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
275
- x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
276
- x = x + self.mlp(self.ln_2(x))
277
- return x
278
-
279
-
280
- class Transformer(nn.Module):
281
- def __init__(
282
- self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
283
- ):
284
- super().__init__()
285
- self.width = width
286
- self.layers = layers
287
- self.resblocks = nn.ModuleList(
288
- [
289
- ResidualAttentionBlock(width, heads, act_layer=act_layer)
290
- for _ in range(layers)
291
- ]
292
- )
293
-
294
- def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
295
- for r in self.resblocks:
296
- x = r(x, attn_mask=attn_mask)
297
- return x
298
-
299
-
300
- class VisualTransformer(nn.Module):
301
- def __init__(
302
- self,
303
- image_size: int,
304
- patch_size: int,
305
- width: int,
306
- layers: int,
307
- heads: int,
308
- output_dim: int,
309
- act_layer: Callable = nn.GELU,
310
- ):
311
- super().__init__()
312
- self.image_size = image_size
313
- self.output_dim = output_dim
314
- self.conv1 = nn.Conv2d(
315
- in_channels=3,
316
- out_channels=width,
317
- kernel_size=patch_size,
318
- stride=patch_size,
319
- bias=False,
320
- )
321
-
322
- scale = width**-0.5
323
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
324
- self.positional_embedding = nn.Parameter(
325
- scale * torch.randn((image_size // patch_size) ** 2 + 1, width)
326
- )
327
- self.ln_pre = LayerNorm(width)
328
-
329
- self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)
330
-
331
- self.ln_post = LayerNorm(width)
332
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
333
-
334
- def lock(self, unlocked_groups=0, freeze_bn_stats=False):
335
- assert (
336
- unlocked_groups == 0
337
- ), "partial locking not currently supported for this model"
338
- for param in self.parameters():
339
- param.requires_grad = False
340
-
341
- def forward(self, x: torch.Tensor):
342
- x = self.conv1(x) # shape = [*, width, grid, grid]
343
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
344
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
345
- x = torch.cat(
346
- [
347
- self.class_embedding.to(x.dtype)
348
- + torch.zeros(
349
- x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
350
- ),
351
- x,
352
- ],
353
- dim=1,
354
- ) # shape = [*, grid ** 2 + 1, width]
355
- x = x + self.positional_embedding.to(x.dtype)
356
- x = self.ln_pre(x)
357
-
358
- x = x.permute(1, 0, 2) # NLD -> LND
359
- x = self.text_branch(x)
360
- x = x.permute(1, 0, 2) # LND -> NLD
361
-
362
- x = self.ln_post(x[:, 0, :])
363
-
364
- if self.proj is not None:
365
- x = x @ self.proj
366
-
367
- return x
368
-
369
-
370
- @dataclass
371
- class CLAPVisionCfg:
372
- layers: Union[Tuple[int, int, int, int], int] = 12
373
- width: int = 768
374
- patch_size: int = 16
375
- image_size: Union[Tuple[int, int], int] = 224
376
- timm_model_name: str = (
377
- None # a valid model name overrides layers, width, patch_size
378
- )
379
- timm_model_pretrained: bool = (
380
- False # use (imagenet) pretrained weights for named model
381
- )
382
- timm_pool: str = (
383
- "avg" # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
384
- )
385
- timm_proj: str = (
386
- "linear" # linear projection for timm model output ('linear', 'mlp', '')
387
- )
388
-
389
-
390
- # Audio Config Class
391
- @dataclass
392
- class CLAPAudioCfp:
393
- model_type: str = "PANN"
394
- model_name: str = "Cnn14"
395
- sample_rate: int = 48000
396
- # Param
397
- audio_length: int = 1024
398
- window_size: int = 1024
399
- hop_size: int = 1024
400
- fmin: int = 50
401
- fmax: int = 14000
402
- class_num: int = 527
403
- mel_bins: int = 64
404
- clip_samples: int = 480000
405
-
406
-
407
- @dataclass
408
- class CLAPTextCfg:
409
- context_length: int
410
- vocab_size: int
411
- width: int
412
- heads: int
413
- layers: int
414
- model_type: str
415
-
416
-
417
- class CLAP(nn.Module):
418
- def __init__(
419
- self,
420
- embed_dim: int,
421
- audio_cfg: CLAPAudioCfp,
422
- text_cfg: CLAPTextCfg,
423
- quick_gelu: bool = False,
424
- enable_fusion: bool = False,
425
- fusion_type: str = 'None',
426
- joint_embed_shape: int = 512,
427
- mlp_act: str = 'relu',
428
- ):
429
- super().__init__()
430
- if isinstance(audio_cfg, dict):
431
- audio_cfg = CLAPAudioCfp(**audio_cfg)
432
- if isinstance(text_cfg, dict):
433
- text_cfg = CLAPTextCfg(**text_cfg)
434
-
435
- self.audio_cfg = audio_cfg
436
- self.text_cfg = text_cfg
437
- self.enable_fusion = enable_fusion
438
- self.fusion_type = fusion_type
439
- self.joint_embed_shape = joint_embed_shape
440
- self.mlp_act = mlp_act
441
-
442
-
443
- self.context_length = text_cfg.context_length
444
-
445
- # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
446
- # memory efficient in recent PyTorch releases (>= 1.10).
447
- # NOTE: timm models always use native GELU regardless of quick_gelu flag.
448
- act_layer = QuickGELU if quick_gelu else nn.GELU
449
-
450
- if mlp_act == 'relu':
451
- mlp_act_layer = nn.ReLU()
452
- elif mlp_act == 'gelu':
453
- mlp_act_layer = nn.GELU()
454
- else:
455
- raise NotImplementedError
456
-
457
- # audio branch
458
- # audio branch parameters
459
- if audio_cfg.model_type == "PANN":
460
- self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
461
- elif audio_cfg.model_type == "HTSAT":
462
- self.audio_branch = create_htsat_model(audio_cfg, enable_fusion, fusion_type)
463
- else:
464
- logging.error(f"Model config for {audio_cfg.model_type} not found")
465
- raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")
466
-
467
- # text branch
468
- # text branch parameters
469
- if text_cfg.model_type == "transformer":
470
- self.text_branch = Transformer(
471
- width=text_cfg.width,
472
- layers=text_cfg.layers,
473
- heads=text_cfg.heads,
474
- act_layer=act_layer,
475
- )
476
- self.vocab_size = text_cfg.vocab_size
477
- self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
478
- self.positional_embedding = nn.Parameter(
479
- torch.empty(self.context_length, text_cfg.width)
480
- )
481
- self.ln_final = LayerNorm(text_cfg.width)
482
- self.text_transform = MLPLayers(units=[self.joint_embed_shape,
483
- self.joint_embed_shape,
484
- self.joint_embed_shape], dropout=0.1)
485
- self.text_projection = nn.Sequential(
486
- nn.Linear(text_cfg.width, self.joint_embed_shape),
487
- mlp_act_layer,
488
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
489
- )
490
- elif text_cfg.model_type == "bert":
491
- self.text_branch = BertModel.from_pretrained("bert-base-uncased")
492
- self.text_transform = MLPLayers(units=[self.joint_embed_shape,
493
- self.joint_embed_shape,
494
- self.joint_embed_shape], dropout=0.1)
495
- self.text_projection = nn.Sequential(
496
- nn.Linear(768, self.joint_embed_shape),
497
- mlp_act_layer,
498
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
499
- )
500
- elif text_cfg.model_type == "roberta":
501
- self.text_branch = RobertaModel.from_pretrained('roberta-base')
502
- self.text_transform = MLPLayers(units=[self.joint_embed_shape,
503
- self.joint_embed_shape,
504
- self.joint_embed_shape], dropout=0.1)
505
- self.text_projection = nn.Sequential(
506
- nn.Linear(768, self.joint_embed_shape),
507
- mlp_act_layer,
508
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
509
- )
510
- elif text_cfg.model_type == "bart":
511
- self.text_branch = BartModel.from_pretrained('facebook/bart-base')
512
- self.text_transform = MLPLayers(units=[self.joint_embed_shape,
513
- self.joint_embed_shape,
514
- self.joint_embed_shape], dropout=0.1)
515
- self.text_projection = nn.Sequential(
516
- nn.Linear(768, self.joint_embed_shape),
517
- mlp_act_layer,
518
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
519
- )
520
- else:
521
- logging.error(f"Model config for {text_cfg.model_type} not found")
522
- raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
523
- self.text_branch_type = text_cfg.model_type
524
- # text branch parameters
525
-
526
- # audio branch parameters
527
- self.audio_transform = MLPLayers(units=[self.joint_embed_shape,
528
- self.joint_embed_shape,
529
- self.joint_embed_shape], dropout=0.1)
530
-
531
- # below here is text branch parameters
532
-
533
- # ============================================================================================================
534
- self.audio_projection = nn.Sequential(
535
- nn.Linear(embed_dim, self.joint_embed_shape),
536
- mlp_act_layer,
537
- nn.Linear(self.joint_embed_shape, self.joint_embed_shape)
538
- )
539
-
540
- self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
541
- self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
542
- self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)
543
-
544
- self.init_text_branch_parameters()
545
-
546
- def init_text_branch_parameters(self):
547
- if self.text_branch_type == "transformer":
548
- nn.init.normal_(self.token_embedding.weight, std=0.02)
549
- nn.init.normal_(self.positional_embedding, std=0.01)
550
- proj_std = (self.text_branch.width**-0.5) * (
551
- (2 * self.text_branch.layers) ** -0.5
552
- )
553
- attn_std = self.text_branch.width**-0.5
554
- fc_std = (2 * self.text_branch.width) ** -0.5
555
- for block in self.text_branch.resblocks:
556
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
557
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
558
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
559
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
560
- if self.text_branch_type == "bert" or self.text_branch_type == "roberta":
561
- width = self.text_branch.embeddings.word_embeddings.weight.shape[-1]
562
- elif self.text_branch_type == "bart":
563
- width = self.text_branch.shared.weight.shape[-1]
564
- else:
565
- width = self.text_branch.width
566
- nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
567
- nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
568
-
569
- # deprecated
570
- # if hasattr(self.visual, 'init_parameters'):
571
- # self.visual.init_parameters()
572
-
573
- # if self.text_projection is not None:
574
- # nn.init.normal_(self.text_projection, std=width**-0.5)
575
-
576
- def build_attention_mask(self):
577
- # lazily create causal attention mask, with full attention between the vision tokens
578
- # pytorch uses additive attention mask; fill with -inf
579
- mask = torch.empty(self.context_length, self.context_length)
580
- mask.fill_(float("-inf"))
581
- mask.triu_(1) # zero out the lower diagonal
582
- return mask
583
-
584
- def encode_audio(self, audio, device):
585
- return self.audio_branch(audio, mixup_lambda=None, device=device) # mix lambda needs to add
586
-
587
- # def list_of_dict_of_tensor2dict_of_tensor(self, x, device):
588
- # tmp = {}
589
- # for k in x[0].keys():
590
- # tmp[k] = []
591
- # for i in range(len(x)):
592
- # tmp[k].append(x[i][k][:77])
593
- # for k in x[0].keys():
594
- # tmp[k] = torch.tensor(tmp[k]).to(device=device, non_blocking=True)
595
- # return tmp
596
-
597
- def encode_text(self, text, device):
598
- if self.text_branch_type == "transformer":
599
- text = text.to(device=device, non_blocking=True)
600
- x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
601
-
602
- x = x + self.positional_embedding
603
- x = x.permute(1, 0, 2) # NLD -> LND
604
- x = self.text_branch(x, attn_mask=self.attn_mask)
605
- x = x.permute(1, 0, 2) # LND -> NLD
606
- x = self.ln_final(x)
607
-
608
- # x.shape = [batch_size, n_ctx, transformer.width]
609
- # take features from the eot embedding (eot_token is the highest number in each sequence)
610
- x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
611
- elif self.text_branch_type == "bert":
612
- # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device)
613
- # text = BatchEncoding(text)
614
- x = self.text_branch(
615
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
616
- attention_mask=text["attention_mask"].to(
617
- device=device, non_blocking=True
618
- ),
619
- token_type_ids=text["token_type_ids"].to(
620
- device=device, non_blocking=True
621
- ),
622
- )["pooler_output"]
623
- x = self.text_projection(x)
624
- elif self.text_branch_type == "roberta":
625
- x = self.text_branch(
626
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
627
- attention_mask=text["attention_mask"].to(
628
- device=device, non_blocking=True
629
- ),
630
- )["pooler_output"]
631
- x = self.text_projection(x)
632
- elif self.text_branch_type == "bart":
633
- x = torch.mean(self.text_branch(
634
- input_ids=text["input_ids"].to(device=device, non_blocking=True),
635
- attention_mask=text["attention_mask"].to(
636
- device=device, non_blocking=True
637
- ),
638
- )["encoder_last_hidden_state"],axis=1)
639
- x = self.text_projection(x)
640
- else:
641
- logging.error(f"Model type {self.text_branch_type} not found")
642
- raise RuntimeError(f"Model type {self.text_branch_type} not found.")
643
- return x
644
-
645
- def forward(self, audio, text, device=None):
646
- """Forward audio and text into the CLAP
647
- Parameters
648
- ----------
649
- audio: torch.Tensor (batch_size, audio_length)
650
- the time-domain audio input / the batch of mel_spec and longer list.
651
- text: torch.Tensor () // need to add
652
- the text token input
653
- """
654
- if device is None:
655
- if audio is not None:
656
- device = audio.device
657
- elif text is not None:
658
- device = text.device
659
- if audio is None and text is None:
660
- # a hack to get the logit scale
661
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
662
- elif audio is None:
663
- return self.encode_text(text, device=device)
664
- elif text is None:
665
- return self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
666
- audio_features = self.audio_projection(self.encode_audio(audio, device=device)["embedding"])
667
- audio_features = F.normalize(audio_features, dim=-1)
668
-
669
- text_features = self.encode_text(
670
- text, device=device
671
- )
672
- # print("text_features", text_features)
673
- # print("text_features.shape", text_features.shape)
674
- # print("text_features.type", type(text_features))
675
- text_features = F.normalize(text_features, dim=-1)
676
-
677
- audio_features_mlp = self.audio_transform(audio_features)
678
- text_features_mlp = self.text_transform(text_features)
679
- # Four outputs: audio features (basic & MLP), text features (basic & MLP)
680
- return (
681
- audio_features,
682
- text_features,
683
- audio_features_mlp,
684
- text_features_mlp,
685
- self.logit_scale_a.exp(),
686
- self.logit_scale_t.exp(),
687
- )
688
-
689
- def get_logit_scale(self):
690
- return self.logit_scale_a.exp(), self.logit_scale_t.exp()
691
-
692
- def get_text_embedding(self, data):
693
- """Get the text embedding from the model
694
- Parameters
695
- ----------
696
- data: torch.Tensor
697
- a tensor of text embedding
698
- Returns
699
- ----------
700
- text_embed: torch.Tensor
701
- a tensor of text_embeds (N, D)
702
- """
703
- device = next(self.parameters()).device
704
- for k in data:
705
- data[k] = data[k].to(device)
706
- text_embeds = self.encode_text(data, device=device)
707
- text_embeds = F.normalize(text_embeds, dim=-1)
708
-
709
- return text_embeds
710
-
711
- def get_audio_embedding(self, data):
712
- """Get the audio embedding from the model
713
- Parameters
714
- ----------
715
- data: a list of dict
716
- the audio input dict list from 'get_audio_feature' method
717
- Returns
718
- ----------
719
- audio_embed: torch.Tensor
720
- a tensor of audio_embeds (N, D)
721
- """
722
- device = next(self.parameters()).device
723
- input_dict = {}
724
- keys = data[0].keys()
725
- for k in keys:
726
- input_dict[k] = torch.cat([d[k].unsqueeze(0) for d in data], dim=0).to(device)
727
-
728
- audio_embeds = self.audio_projection(self.encode_audio(input_dict, device=device)["embedding"])
729
- audio_embeds = F.normalize(audio_embeds, dim=-1)
730
-
731
- return audio_embeds
732
-
733
-
734
-
735
- def audio_infer(self, audio, hopsize=None, device=None):
736
- """Forward one audio and produce the audio embedding
737
- Parameters
738
- ----------
739
- audio: (audio_length)
740
- the time-domain audio input, notice that it must be only one input
741
- hopsize: int
742
- the overlap hopsize as the sliding window
743
- Returns
744
- ----------
745
- output_dict: {
746
- key: [n, (embedding_shape)] if "HTS-AT"
747
- or
748
- key: [(embedding_shape)] if "PANN"
749
- }
750
- the list of key values of the audio branch
751
- """
752
-
753
- assert not self.training, "the inference mode must be run at eval stage"
754
- output_dict = {}
755
- # PANN
756
- if self.audio_cfg.model_type == "PANN":
757
- audio_input = audio.unsqueeze(dim=0)
758
- output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
759
- elif self.audio_cfg.model_type == "HTSAT":
760
- # repeat
761
- audio_len = len(audio)
762
- k = self.audio_cfg.clip_samples // audio_len
763
- if k > 1:
764
- audio = audio.repeat(k)
765
- audio_len = len(audio)
766
-
767
- if hopsize is None:
768
- hopsize = min(hopsize, audio_len)
769
-
770
- if audio_len > self.audio_cfg.clip_samples:
771
- audio_input = [
772
- audio[pos : pos + self.audio_cfg.clip_samples].clone()
773
- for pos in range(
774
- 0, audio_len - self.audio_cfg.clip_samples, hopsize
775
- )
776
- ]
777
- audio_input.append(audio[-self.audio_cfg.clip_samples :].clone())
778
- audio_input = torch.stack(audio_input)
779
- output_dict[key] = self.encode_audio(audio_input, device=device)[key]
780
- else:
781
- audio_input = audio.unsqueeze(dim=0)
782
- output_dict[key] = self.encode_audio(audio_input, device=device)[key].squeeze(dim=0)
783
-
784
- return output_dict
785
-
786
-
787
- def convert_weights_to_fp16(model: nn.Module):
788
- """Convert applicable model parameters to fp16"""
789
-
790
- def _convert_weights_to_fp16(l):
791
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
792
- l.weight.data = l.weight.data.half()
793
- if l.bias is not None:
794
- l.bias.data = l.bias.data.half()
795
-
796
- if isinstance(l, nn.MultiheadAttention):
797
- for attr in [
798
- *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
799
- "in_proj_bias",
800
- "bias_k",
801
- "bias_v",
802
- ]:
803
- tensor = getattr(l, attr)
804
- if tensor is not None:
805
- tensor.data = tensor.data.half()
806
-
807
- for name in ["text_projection", "proj"]:
808
- if hasattr(l, name):
809
- attr = getattr(l, name)
810
- if attr is not None:
811
- attr.data = attr.data.half()
812
-
813
- model.apply(_convert_weights_to_fp16)
814
-
815
-
816
- # Ignore the state dict of the vision part
817
- def build_model_from_openai_state_dict(state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = 'None'):
818
-
819
- embed_dim = model_cfg["embed_dim"]
820
- audio_cfg = model_cfg["audio_cfg"]
821
- text_cfg = model_cfg["text_cfg"]
822
- context_length = state_dict["positional_embedding"].shape[0]
823
- vocab_size = state_dict["token_embedding.weight"].shape[0]
824
- transformer_width = state_dict["ln_final.weight"].shape[0]
825
- transformer_heads = transformer_width // 64
826
- transformer_layers = len(
827
- set(
828
- k.split(".")[2]
829
- for k in state_dict
830
- if k.startswith(f"transformer.resblocks")
831
- )
832
- )
833
-
834
- audio_cfg = CLAPAudioCfp(**audio_cfg)
835
- text_cfg = CLAPTextCfg(**text_cfg)
836
-
837
- model = CLAP(
838
- embed_dim,
839
- audio_cfg=audio_cfg,
840
- text_cfg=text_cfg,
841
- quick_gelu=True, # OpenAI models were trained with QuickGELU
842
- enable_fusion=enable_fusion,
843
- fusion_type=fusion_type
844
- )
845
- state_dict["logit_scale_a"] = state_dict["logit_scale"]
846
- state_dict["logit_scale_t"] = state_dict["logit_scale"]
847
- pop_keys = list(state_dict.keys())[::]
848
- # pop the visual branch saved weights
849
- for key in pop_keys:
850
- if key.startswith("visual."):
851
- state_dict.pop(key, None)
852
-
853
- for key in ["logit_scale", "input_resolution", "context_length", "vocab_size"]:
854
- state_dict.pop(key, None)
855
-
856
- # not use fp16
857
- # convert_weights_to_fp16(model)
858
- model.load_state_dict(state_dict, strict=False)
859
- return model.eval()
860
-
861
-
862
- def trace_model(model, batch_size=256, device=torch.device("cpu")):
863
- model.eval()
864
- audio_length = model.audio_cfg.audio_length
865
- example_audio = torch.ones((batch_size, audio_length), device=device)
866
- example_text = torch.zeros(
867
- (batch_size, model.context_length), dtype=torch.int, device=device
868
- )
869
- model = torch.jit.trace_module(
870
- model,
871
- inputs=dict(
872
- forward=(example_audio, example_text),
873
- encode_text=(example_text,),
874
- encode_image=(example_audio,),
875
- ),
876
- )
877
- model.audio_cfg.audio_length = audio_length # Question: what does this do?
878
- return model
 
 
 
 
 
1
 
 
 
 
 
 
2
 
3
+ import cv2
4
+ import os
5
+ from natsort import natsorted
6
 
7
+ image_folder = '/Users/playment/Parth /Codes/data'
8
+ video_name = '/Users/playment/Parth /Codes/AI_VideoEditing/video.avi'
 
 
9
 
10
+ images = [img for img in os.listdir(image_folder) if img.endswith(".jpg")]
11
+ frame = cv2.imread(os.path.join(image_folder, images[0]))
12
+ height, width, layers = frame.shape
13
+ images = natsorted(images)
14
+ video = cv2.VideoWriter(video_name, 0, 60, (width,height))
15
 
16
+ for image in images:
17
+     video.write(cv2.imread(os.path.join(image_folder, image)))
 
 
 
18
 
19
+ cv2.destroyAllWindows()
20
+ video.release()
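
For reference, a slightly more defensive variant of the frames-to-video step in the new app.py is sketched below (it is not part of this commit). It keeps the same folder paths and the 60 fps setting, but the explicit "MJPG" FourCC and the empty-folder guard are additions of this sketch; passing 0 as the FourCC, as the script above does, falls back to platform-dependent default behaviour. OpenCV (cv2) and natsort are assumed to be installed.

import os

import cv2
from natsort import natsorted

# Same inputs as the script above; adjust paths as needed.
image_folder = '/Users/playment/Parth /Codes/data'
video_name = '/Users/playment/Parth /Codes/AI_VideoEditing/video.avi'
fps = 60

# Collect the .jpg frames and sort them naturally (frame2.jpg before frame10.jpg).
images = natsorted(img for img in os.listdir(image_folder) if img.endswith(".jpg"))
if not images:
    raise SystemExit(f"No .jpg frames found in {image_folder}")

# Use the first frame to fix the output resolution.
first = cv2.imread(os.path.join(image_folder, images[0]))
height, width, _ = first.shape

# Explicit FourCC; 'MJPG' is a safe choice for .avi containers.
fourcc = cv2.VideoWriter_fourcc(*"MJPG")
writer = cv2.VideoWriter(video_name, fourcc, fps, (width, height))

# Write frames in order, then close the container.
for name in images:
    writer.write(cv2.imread(os.path.join(image_folder, name)))

writer.release()

The only substantive difference from the committed script is the FourCC: with an explicit codec the .avi is playable across platforms, whereas a FourCC of 0 leaves the codec choice to the local OpenCV build.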