Gregniuki committed on
Commit 7726dc4 · verified · 1 Parent(s): ea43fca

Delete model/modules.py

Files changed (1)
  1. model/modules.py +0 -575
model/modules.py DELETED
@@ -1,575 +0,0 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Optional
import math

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio

from einops import rearrange
from x_transformers.x_transformers import apply_rotary_pos_emb


# raw wav to mel spec

class MelSpec(nn.Module):
    def __init__(
        self,
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        target_sample_rate = 24_000,
        normalize = False,
        power = 1,
        norm = None,
        center = True,
    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels

        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate = target_sample_rate,
            n_fft = filter_length,
            win_length = win_length,
            hop_length = hop_length,
            n_mels = n_mel_channels,
            power = power,
            center = center,
            normalized = normalize,
            norm = norm,
        )

        self.register_buffer('dummy', torch.tensor(0), persistent = False)

    def forward(self, inp):
        if len(inp.shape) == 3:
            inp = rearrange(inp, 'b 1 nw -> b nw')

        assert len(inp.shape) == 2

        if self.dummy.device != inp.device:
            self.to(inp.device)

        mel = self.mel_stft(inp)
        mel = mel.clamp(min = 1e-5).log()
        return mel

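# Editor's usage sketch (not part of the original modules.py): a minimal, hedged example of the
# expected shapes when MelSpec converts a raw waveform into a clamped log-mel spectrogram.
# The helper name below is hypothetical and exists only for illustration.
def _melspec_usage_sketch():
    mel_spec = MelSpec()                     # defaults: 24 kHz, 100 mel channels, hop 256
    wave = torch.randn(2, 24_000)            # (b, nw): two 1-second clips at 24 kHz
    mel = mel_spec(wave)                     # (b, n_mel_channels, frames), log of values clamped at 1e-5
    assert mel.shape[0] == 2 and mel.shape[1] == 100
    return mel
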
# sinusoidal position embedding

class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# convolutional position embedding

class ConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size = 31, groups = 16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.)

        x = rearrange(x, 'b n d -> b d n')
        x = self.conv1d(x)
        out = rearrange(x, 'b d n -> b n d')

        if mask is not None:
            out = out.masked_fill(~mask, 0.)

        return out


# rotary positional embedding related

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)

def get_pos_embed_indices(start, length, max_pos, scale=1.):
    # length = length if isinstance(length, int) else length.max()
    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    pos = start.unsqueeze(1) + (
        torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
        scale.unsqueeze(1)).long()
    # avoid extra long error.
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos

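# Editor's usage sketch (not part of the original modules.py): a hedged illustration of the
# shapes produced by the two helpers above. How the downstream model consumes the gathered
# table is an assumption here, not something defined in this module; the helper name is hypothetical.
def _rope_helpers_sketch():
    dim, max_pos = 64, 4096
    freqs_cis = precompute_freqs_cis(dim, max_pos)                   # (max_pos, dim): [cos | sin] halves
    start = torch.zeros(2, dtype=torch.long)                         # per-sample start offsets, batch of 2
    pos = get_pos_embed_indices(start, length=128, max_pos=max_pos)  # (2, 128), clipped below max_pos
    per_token_freqs = freqs_cis[pos]                                 # (2, 128, dim) gathered per position
    return per_token_freqs
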
# Global Response Normalization layer (Instance Normalization ?)

class GRN(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=1, keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108

class ConvNeXtV2Block(nn.Module):
    def __init__(
        self,
        dim: int,
        intermediate_dim: int,
        dilation: int = 1,
    ):
        super().__init__()
        padding = (dilation * (7 - 1)) // 2
        self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation)  # depthwise conv
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, intermediate_dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.grn = GRN(intermediate_dim)
        self.pwconv2 = nn.Linear(intermediate_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = x.transpose(1, 2)  # b n d -> b d n
        x = self.dwconv(x)
        x = x.transpose(1, 2)  # b d n -> b n d
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.grn(x)
        x = self.pwconv2(x)
        return residual + x


# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation

class AdaLayerNormZero(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 6)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb = None):
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)

        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation

class AdaLayerNormZero_Final(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, dim * 2)

        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(self.silu(emb))
        scale, shift = torch.chunk(emb, 2, dim=1)

        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x


# FeedForward

class FeedForward(nn.Module):
    def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        activation = nn.GELU(approximate=approximate)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            activation
        )
        self.ff = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.ff(x)

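# Editor's shape-check sketch (not part of the original modules.py), hedged and for illustration only:
# one conditioning vector per sample yields shift/scale/gate terms for both the attention and MLP
# sub-layers via AdaLayerNormZero. The helper name is hypothetical.
def _adaln_zero_sketch():
    dim = 512
    norm = AdaLayerNormZero(dim)
    x = torch.randn(2, 100, dim)    # (b, n, d) sequence features
    t = torch.randn(2, dim)         # (b, d) conditioning embedding, e.g. a timestep embedding
    x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb=t)
    assert x_mod.shape == x.shape and gate_msa.shape == (2, dim)
    return x_mod
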
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py

class Attention(nn.Module):
    def __init__(
        self,
        processor: JointAttnProcessor | AttnProcessor,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        context_dim: Optional[int] = None,  # if not None -> joint attention
        context_pre_only = None,
    ):
        super().__init__()

        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Attention requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

        self.processor = processor

        self.dim = dim
        self.heads = heads
        self.inner_dim = dim_head * heads
        self.dropout = dropout

        self.context_dim = context_dim
        self.context_pre_only = context_pre_only

        self.to_q = nn.Linear(dim, self.inner_dim)
        self.to_k = nn.Linear(dim, self.inner_dim)
        self.to_v = nn.Linear(dim, self.inner_dim)

        if self.context_dim is not None:
            self.to_k_c = nn.Linear(context_dim, self.inner_dim)
            self.to_v_c = nn.Linear(context_dim, self.inner_dim)
            if self.context_pre_only is not None:
                self.to_q_c = nn.Linear(context_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(self.inner_dim, dim))
        self.to_out.append(nn.Dropout(dropout))

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_out_c = nn.Linear(self.inner_dim, dim)

    def forward(
        self,
        x: float['b n d'],  # noised input x
        c: float['b n d'] = None,  # context c
        mask: bool['b n'] | None = None,
        rope = None,  # rotary position embedding for x
        c_rope = None,  # rotary position embedding for c
    ) -> torch.Tensor:
        if c is not None:
            return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
        else:
            return self.processor(self, x, mask = mask, rope = rope)


# Attention processor

class AttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float['b n d'],  # noised input x
        mask: bool['b n'] | None = None,
        rope = None,  # rotary position embedding
    ) -> torch.FloatTensor:

        batch_size = x.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # apply rotary position embedding
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)

            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)

        # attention
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = mask
            attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)

        if mask is not None:
            mask = rearrange(mask, 'b n -> b n 1')
            x = x.masked_fill(~mask, 0.)

        return x

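# Editor's usage sketch (not part of the original modules.py): a hedged example wiring Attention
# with the self-attention AttnProcessor above and a padding mask. Passing rope=None skips the
# rotary branch; the (freqs, xpos_scale) tuple format is only what the processor unpacks.
# The helper name is hypothetical.
def _self_attention_sketch():
    dim, heads, dim_head = 512, 8, 64
    attn = Attention(processor=AttnProcessor(), dim=dim, heads=heads, dim_head=dim_head)
    x = torch.randn(2, 100, dim)                  # (b, n, d) noised input
    mask = torch.ones(2, 100, dtype=torch.bool)   # True = keep, False = padding
    mask[1, 80:] = False                          # second sample padded after 80 frames
    out = attn(x, mask=mask, rope=None)           # (b, n, d), padded positions zeroed at the end
    assert out.shape == x.shape
    return out
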
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py

class JointAttnProcessor:
    def __init__(self):
        pass

    def __call__(
        self,
        attn: Attention,
        x: float['b n d'],  # noised input x
        c: float['b nt d'] = None,  # context c, here text
        mask: bool['b n'] | None = None,
        rope = None,  # rotary position embedding for x
        c_rope = None,  # rotary position embedding for c
    ) -> torch.FloatTensor:
        residual = x

        batch_size = c.shape[0]

        # `sample` projections.
        query = attn.to_q(x)
        key = attn.to_k(x)
        value = attn.to_v(x)

        # `context` projections.
        c_query = attn.to_q_c(c)
        c_key = attn.to_k_c(c)
        c_value = attn.to_v_c(c)

        # apply rope for context and noised input independently
        if rope is not None:
            freqs, xpos_scale = rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
            query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
            key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
        if c_rope is not None:
            freqs, xpos_scale = c_rope
            q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
            c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
            c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)

        # attention
        query = torch.cat([query, c_query], dim=1)
        key = torch.cat([key, c_key], dim=1)
        value = torch.cat([value, c_value], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # mask. e.g. inference got a batch with different target durations, mask out the padding
        if mask is not None:
            attn_mask = F.pad(mask, (0, c.shape[1]), value = True)  # no mask for c (text)
            attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
            attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
        else:
            attn_mask = None

        x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
        x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        x = x.to(query.dtype)

        # Split the attention outputs.
        x, c = (
            x[:, :residual.shape[1]],
            x[:, residual.shape[1]:],
        )

        # linear proj
        x = attn.to_out[0](x)
        # dropout
        x = attn.to_out[1](x)
        if not attn.context_pre_only:
            c = attn.to_out_c(c)

        if mask is not None:
            mask = rearrange(mask, 'b n -> b n 1')
            x = x.masked_fill(~mask, 0.)
            # c = c.masked_fill(~mask, 0.)  # no mask for c (text)

        return x, c

# DiT Block

class DiTBlock(nn.Module):

    def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
        super().__init__()

        self.attn_norm = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor = AttnProcessor(),
            dim = dim,
            heads = heads,
            dim_head = dim_head,
            dropout = dropout,
        )

        self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")

    def forward(self, x, t, mask = None, rope = None):  # x: noised input, t: time embedding
        # pre-norm & modulation for attention input
        norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)

        # attention
        attn_output = self.attn(x=norm, mask=mask, rope=rope)

        # process attention output for input x
        x = x + gate_msa.unsqueeze(1) * attn_output

        norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        ff_output = self.ff(norm)
        x = x + gate_mlp.unsqueeze(1) * ff_output

        return x


# MMDiT Block https://arxiv.org/abs/2403.03206

class MMDiTBlock(nn.Module):
    r"""
    modified from diffusers/src/diffusers/models/attention.py

    notes.
    _c: context related. text, cond, etc. (left part in sd3 fig2.b)
    _x: noised input related. (right part)
    context_pre_only: last layer only do prenorm + modulation cuz no more ffn
    """

    def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
        super().__init__()

        self.context_pre_only = context_pre_only

        self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
        self.attn_norm_x = AdaLayerNormZero(dim)
        self.attn = Attention(
            processor = JointAttnProcessor(),
            dim = dim,
            heads = heads,
            dim_head = dim_head,
            dropout = dropout,
            context_dim = dim,
            context_pre_only = context_pre_only,
        )

        if not context_pre_only:
            self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
            self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
        else:
            self.ff_norm_c = None
            self.ff_c = None
        self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")

    def forward(self, x, c, t, mask = None, rope = None, c_rope = None):  # x: noised input, c: context, t: time embedding
        # pre-norm & modulation for attention input
        if self.context_pre_only:
            norm_c = self.attn_norm_c(c, t)
        else:
            norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
        norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)

        # attention
        x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)

        # process attention output for context c
        if self.context_pre_only:
            c = None
        else:  # if not last layer
            c = c + c_gate_msa.unsqueeze(1) * c_attn_output

            norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
            c_ff_output = self.ff_c(norm_c)
            c = c + c_gate_mlp.unsqueeze(1) * c_ff_output

        # process attention output for input x
        x = x + x_gate_msa.unsqueeze(1) * x_attn_output

        norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
        x_ff_output = self.ff_x(norm_x)
        x = x + x_gate_mlp.unsqueeze(1) * x_ff_output

        return c, x


# time step conditioning embedding

class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(freq_embed_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, timestep: float['b']):
        time_hidden = self.time_embed(timestep)
        time = self.time_mlp(time_hidden)  # b d
        return time
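
# Editor's end-to-end sketch (not part of the original modules.py): a hedged example combining
# TimestepEmbedding with a single DiTBlock. The choice of timestep range and the overall model
# wiring are assumptions for illustration only; the helper name is hypothetical.
def _dit_block_sketch():
    dim, heads, dim_head = 512, 8, 64
    time_embed = TimestepEmbedding(dim)
    block = DiTBlock(dim=dim, heads=heads, dim_head=dim_head)

    x = torch.randn(2, 100, dim)               # (b, n, d) noised sequence features
    t = torch.rand(2)                          # (b,) timesteps
    t_emb = time_embed(t)                      # (b, d) conditioning vector
    mask = torch.ones(2, 100, dtype=torch.bool)

    out = block(x, t_emb, mask=mask, rope=None)  # (b, n, d)
    assert out.shape == x.shape
    return out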