Gertie01 committed on
Commit c4114ed
1 Parent(s): aad1cac

Create musiclm_pytorch.py

Files changed (1)
  1. musiclm_pytorch.py +555 -0
musiclm_pytorch.py ADDED
@@ -0,0 +1,555 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum

from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking

from audiolm_pytorch import AudioLM

from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ

from einops import rearrange, repeat, reduce, pack, unpack

from beartype.typing import List, Optional, Tuple
from beartype import beartype

# functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def round_down_nearest_multiple(n, divisor):
    return n // divisor * divisor

# tensor functions

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)

# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned

def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(dtype)

    return rearrange(pe, '(h w) d -> h w d', h = h, w = w)

# biasless layernorm

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

def FeedForward(dim, mult = 4, dropout = 0.):
    dim_hidden = int(dim * mult * 2 / 3)

    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, dim_hidden * 2, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim, bias = False)
    )

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.attn_dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        b, n, _, device = *x.shape, x.device

        # prenorm

        x = self.norm(x)

        # project for queries, keys, values

        q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)

        # split for multi-headed attention

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale

        # similarities

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate

        out = einsum('b h i j, b h j d -> b h i d', attn, v)

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.
    ):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
            ]))

    def forward(self, x, mask = None):

        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x
            x = ff(x) + x

        return x

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
    return (t, t) if not isinstance(t, tuple) else t

class AudioSpectrogramTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        patch_size = 16,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80
    ):
        super().__init__()
        self.dim = dim

        self.patch_size = pair(patch_size)
        self.to_patch_tokens = nn.Conv2d(self.patch_size[0] * self.patch_size[1], dim, 1)

        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = spec_power,
            win_length = spec_win_length,
            hop_length = spec_hop_length,
            pad = spec_pad,
            center = spec_center,
            pad_mode = spec_pad_mode
        )

        # SpecAugment - seems to be widely used in audio field https://arxiv.org/abs/1904.08779

        self.aug = torch.nn.Sequential(
            TimeStretch(spec_aug_stretch_factor, fixed_rate = True),
            FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
            TimeMasking(time_mask_param = spec_aug_time_mask),
        )

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout
        )

        self.norm = LayerNorm(dim)

    def forward(self, x):
        x = self.spec(x)

        if self.training:
            x = self.aug(x)

        # automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes

        height, width = x.shape[-2:]
        patch_height, patch_width = self.patch_size

        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))

        if (height, width) != (rounded_height, rounded_width): # just keep printing to be annoying until it is fixed
            print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        x = x[..., :rounded_height, :rounded_width]

        # to patches

        x = rearrange(x, 'b (h p1) (w p2) -> b (p1 p2) h w', p1 = patch_height, p2 = patch_width)
        x = self.to_patch_tokens(x)

        # 2d sinusoidal positional embedding

        x = rearrange(x, 'b c h w -> b h w c')
        x = x + posemb_sincos_2d(x)

        # attention, what else

        x = rearrange(x, 'b ... c -> b (...) c')

        x = self.transformer(x)

        # final global average and norm (most recent papers show this is superior to CLS token)

        x = reduce(x, 'b n d -> b d', 'mean')

        return self.norm(x)

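# a minimal forward-pass sketch for the audio tower (values below are illustrative,
# not from the file itself): raw waveforms of shape (batch, samples) are converted
# to a spectrogram, patched, and pooled into a single (batch, dim) embedding
#
#   ast = AudioSpectrogramTransformer(dim = 512, depth = 6)
#   wavs = torch.randn(2, 1024)          # (batch, num audio samples)
#   audio_embeds = ast(wavs)             # (2, 512)
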
# text transformer

@beartype
class TextTransformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        num_tokens = tokenizer.vocab_size,
        max_seq_len = 256,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4,
        pad_id = 0
    ):
        super().__init__()
        self.dim = dim

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.cls_token = nn.Parameter(torch.randn(dim))

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.pad_id = pad_id
        self.norm = LayerNorm(dim)

    def forward(
        self,
        x = None,
        raw_texts: Optional[List[str]] = None,
        mask = None
    ):
        assert exists(x) ^ exists(raw_texts)

        if exists(raw_texts):
            x = tokenizer.tokenize(raw_texts)

        if not exists(mask):
            mask = x != self.pad_id

        b, n, device = *x.shape, x.device

        # token embedding + positional embedding

        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device = device))

        # cls tokens, as in bert

        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
        x, ps = pack([cls_tokens, x], 'b * d')

        # account for attending to cls token with self attention mask

        mask = F.pad(mask, (1, 0), value = True)

        # attention

        x = self.transformer(x, mask = mask)

        # unpack the cls tokens

        cls_tokens, _ = unpack(x, ps, 'b * d')

        return self.norm(cls_tokens)

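# a minimal sketch for the text tower (illustrative): either pre-tokenized ids or raw
# strings may be passed; the normalized cls embedding is returned with shape (batch, dim)
#
#   text_transformer = TextTransformer(dim = 512, depth = 6)
#   text_embeds = text_transformer(raw_texts = ['chill lofi beat', 'driving techno'])   # (2, 512)
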
# main classes

@beartype
class MuLaN(nn.Module):
    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer,
        dim_latent = 128,                       # they use 128
        decoupled_contrastive_learning = True,  # think this was used, make it optional
    ):
        super().__init__()
        self.dim_latent = dim_latent

        self.audio = audio_transformer
        self.text = text_transformer

        self.temperature = nn.Parameter(torch.tensor(1.))

        self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
        self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

        self.decoupled_contrastive_learning = decoupled_contrastive_learning

    def get_audio_latents(
        self,
        wavs
    ):
        audio_embeds = self.audio(wavs)
        audio_latents = self.audio_to_latents(audio_embeds)
        return l2norm(audio_latents)

    def get_text_latents(
        self,
        texts = None,
        raw_texts: Optional[List[str]] = None
    ):
        text_embeds = self.text(texts, raw_texts = raw_texts)  # forward raw texts so text-only calls work
        text_latents = self.text_to_latents(text_embeds)
        return l2norm(text_latents)

    def forward(
        self,
        wavs,
        texts = None,
        raw_texts: Optional[List[str]] = None,
        return_similarities = False
    ):
        batch, device = wavs.shape[0], wavs.device

        audio_latents = self.get_audio_latents(wavs)
        text_latents = self.get_text_latents(texts, raw_texts = raw_texts)

        cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)

        assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal'

        if return_similarities:
            return cosine_sim

        cosine_sim = cosine_sim * self.temperature.exp()

        cosine_sim_exp = cosine_sim.exp()

        numerator = cosine_sim_exp.diag()

        if self.decoupled_contrastive_learning:
            # decoupled contrastive learning drops the positive pair from the denominator
            eye = torch.eye(batch, device = device, dtype = torch.bool)  # masked_fill requires a boolean mask
            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')

        contrastive_loss = -log(numerator / denominator)
        return contrastive_loss.mean()

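# a training-step sketch for MuLaN (illustrative values, assuming the two towers above):
# the forward pass returns the (optionally decoupled) contrastive loss over a batch of
# paired audio and text, which is backpropagated as usual
#
#   mulan = MuLaN(
#       audio_transformer = AudioSpectrogramTransformer(dim = 512, depth = 6),
#       text_transformer = TextTransformer(dim = 512, depth = 6)
#   )
#
#   wavs = torch.randn(2, 1024)
#   texts = torch.randint(0, 20000, (2, 256))
#
#   loss = mulan(wavs, texts = texts)
#   loss.backward()
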
# music lm

@beartype
class MuLaNEmbedQuantizer(nn.Module):
    def __init__(
        self,
        mulan: MuLaN,
        conditioning_dims: Tuple[int, ...],
        rq_num_quantizers = 8,
        rq_ema_decay = 0.9,
        codebook_size = 1024,
        namespaces: Tuple[str, ...] = ('semantic', 'coarse', 'fine'),
    ):
        super().__init__()
        self.mulan = mulan

        assert len(namespaces) > 0
        self.namespaces = namespaces
        self.conditioning_dims = conditioning_dims

        assert len(conditioning_dims) == len(namespaces), 'number of conditioning dimensions must be equal to number of namespaces'

        dim = mulan.dim_latent

        self.rq = ResidualVQ(
            dim = dim,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            decay = rq_ema_decay,
            commitment_weight = 0,      # only use EMA to update codebooks
            kmeans_init = True,
            threshold_ema_dead_code = 2,
            quantize_dropout = False    # no quantize dropout
        )

        self.dim = dim
        self.num_codebooks = rq_num_quantizers

        self.cond_embeddings = nn.ParameterDict({})

        for namespace, conditioning_dim in zip(namespaces, conditioning_dims):
            cond_embeddings = nn.Parameter(torch.randn(rq_num_quantizers, codebook_size, conditioning_dim))
            nn.init.normal_(cond_embeddings, std = 0.02)

            self.cond_embeddings[namespace] = cond_embeddings

        self.set_default_namespace(namespaces[0])

    def set_default_namespace(self, namespace):
        self._default_namespace = namespace

    def forward(
        self,
        wavs = None,
        texts = None,
        namespace = None
    ):
        assert exists(wavs) ^ exists(texts)

        namespace = default(namespace, self._default_namespace)
        assert namespace in self.namespaces, f'namespace {namespace} not found'
        cond_embeddings = self.cond_embeddings[namespace]

        with torch.no_grad():
            self.mulan.eval()

            # sound and language live in joint embedding space because of contrastive learning

            if exists(wavs):
                latents = self.mulan.get_audio_latents(wavs)
            elif exists(texts):
                latents = self.mulan.get_text_latents(texts)

        _, indices, _ = self.rq(latents)

        batch, num_codebooks, dim = indices.shape[0], self.num_codebooks, cond_embeddings.shape[-1]

        cond_embeddings = repeat(cond_embeddings, 'q c d -> b q c d', b = batch)
        indices = repeat(indices, 'b q -> b q 1 d', q = num_codebooks, d = dim)

        cond_embeddings = cond_embeddings.gather(2, indices)
        return rearrange(cond_embeddings, 'b q 1 d -> b q d')

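# a conditioning sketch (illustrative dims): the quantizer maps a MuLaN latent to
# rq_num_quantizers residual codes and gathers one learned conditioning embedding per
# code, returning (batch, rq_num_quantizers, conditioning_dim) for the chosen namespace.
# conditioning_dims is assumed to match the model dimensions of the semantic, coarse
# and fine transformers of the AudioLM being conditioned
#
#   quantizer = MuLaNEmbedQuantizer(
#       mulan = mulan,                              # a trained MuLaN, as sketched above
#       conditioning_dims = (1024, 1024, 1024),
#       namespaces = ('semantic', 'coarse', 'fine')
#   )
#
#   cond_tokens = quantizer(wavs = torch.randn(2, 1024), namespace = 'semantic')   # (2, 8, 1024)
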
@beartype
class MusicLM(nn.Module):
    def __init__(
        self,
        audio_lm: AudioLM,
        mulan_embed_quantizer: MuLaNEmbedQuantizer
    ):
        super().__init__()
        self.mulan_embed_quantizer = mulan_embed_quantizer
        self.audio_lm = audio_lm

    @torch.no_grad()
    def forward(
        self,
        raw_texts: List[str],
        **audio_lm_kwargs
    ):
        self.eval()

        texts = tokenizer.tokenize(raw_texts)
        cond_tokens = self.mulan_embed_quantizer(texts = texts)

        wavs = self.audio_lm.generate(cond_tokens = cond_tokens, **audio_lm_kwargs)
        return wavs
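
# an end-to-end generation sketch (illustrative): MusicLM wraps a trained AudioLM from
# audiolm_pytorch (its construction and training are not shown here and are assumptions)
# together with the MuLaNEmbedQuantizer, and turns raw text prompts into waveforms
#
#   musiclm = MusicLM(
#       audio_lm = audio_lm,                        # a pretrained audiolm_pytorch.AudioLM
#       mulan_embed_quantizer = quantizer
#   )
#
#   generated_wavs = musiclm(['calm piano with soft rain in the background'])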