Staticaliza committed on
Commit
c38f823
1 Parent(s): 76f383f

Upload 3 files

model/backbones/model_backbones_dit.py ADDED
@@ -0,0 +1,158 @@
+ """
+ ein notation:
+ b - batch
+ n - sequence
+ nt - text sequence
+ nw - raw wave length
+ d - dimension
+ """
+
+ from __future__ import annotations
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from einops import repeat
+
+ from x_transformers.x_transformers import RotaryEmbedding
+
+ from model.modules import (
+     TimestepEmbedding,
+     ConvNeXtV2Block,
+     ConvPositionEmbedding,
+     DiTBlock,
+     AdaLayerNormZero_Final,
+     precompute_freqs_cis, get_pos_embed_indices,
+ )
+
+
+ # Text embedding
+
+ class TextEmbedding(nn.Module):
+     def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
+         super().__init__()
+         self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+
+         if conv_layers > 0:
+             self.extra_modeling = True
+             self.precompute_max_pos = 4096  # ~44s of 24kHz audio
+             self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
+             self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
+         else:
+             self.extra_modeling = False
+
+     def forward(self, text: int['b nt'], seq_len, drop_text = False):
+         batch, text_len = text.shape[0], text.shape[1]
+         text = text + 1  # use 0 as filler token; batch preprocessing pads with -1, see list_str_to_idx()
+         text = text[:, :seq_len]  # truncate if there are more character tokens than mel spec frames
+         text = F.pad(text, (0, seq_len - text_len), value = 0)
+
+         if drop_text:  # cfg for text
+             text = torch.zeros_like(text)
+
+         text = self.text_embed(text)  # b n -> b n d
+
+         # possible extra modeling
+         if self.extra_modeling:
+             # sinusoidal pos emb
+             batch_start = torch.zeros((batch,), dtype=torch.long)
+             pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
+             text_pos_embed = self.freqs_cis[pos_idx]
+             text = text + text_pos_embed
+
+             # convnextv2 blocks
+             text = self.text_blocks(text)
+
+         return text
+
+
+ # noised input audio and context mixing embedding
+
+ class InputEmbedding(nn.Module):
+     def __init__(self, mel_dim, text_dim, out_dim):
+         super().__init__()
+         self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
+         self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
+
+     def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
+         if drop_audio_cond:  # cfg for cond audio
+             cond = torch.zeros_like(cond)
+
+         x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
+         x = self.conv_pos_embed(x) + x
+         return x
+
+
+ # Transformer backbone using DiT blocks
+
+ class DiT(nn.Module):
+     def __init__(self, *,
+                  dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
+                  mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
+                  long_skip_connection = False,
+     ):
+         super().__init__()
+
+         self.time_embed = TimestepEmbedding(dim)
+         if text_dim is None:
+             text_dim = mel_dim
+         self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
+         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
+
+         self.rotary_embed = RotaryEmbedding(dim_head)
+
+         self.dim = dim
+         self.depth = depth
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 DiTBlock(
+                     dim = dim,
+                     heads = heads,
+                     dim_head = dim_head,
+                     ff_mult = ff_mult,
+                     dropout = dropout,
+                 )
+                 for _ in range(depth)
+             ]
+         )
+         self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
+
+         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
+         self.proj_out = nn.Linear(dim, mel_dim)
+
+     def forward(
+         self,
+         x: float['b n d'],             # noised input audio
+         cond: float['b n d'],          # masked cond audio
+         text: int['b nt'],             # text
+         time: float['b'] | float[''],  # time step
+         drop_audio_cond,               # cfg for cond audio
+         drop_text,                     # cfg for text
+         mask: bool['b n'] | None = None,
+     ):
+         batch, seq_len = x.shape[0], x.shape[1]
+         if time.ndim == 0:
+             time = repeat(time, ' -> b', b = batch)
+
+         # t: conditioning time; text and masked cond audio are mixed into the noised input x
+         t = self.time_embed(time)
+         text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
+         x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
+
+         rope = self.rotary_embed.forward_from_seq_len(seq_len)
+
+         if self.long_skip_connection is not None:
+             residual = x
+
+         for block in self.transformer_blocks:
+             x = block(x, t, mask = mask, rope = rope)
+
+         if self.long_skip_connection is not None:
+             x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
+
+         x = self.norm_out(x, t)
+         output = self.proj_out(x)
+
+         return output
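
Note: the snippet below is an illustrative smoke test, not part of this commit. It assumes `x_transformers` is installed, that `model.modules` provides the imported blocks, and that this file is importable as `model.backbones.model_backbones_dit`; all sizes are made up to show expected shapes.

# Hypothetical smoke test for the DiT backbone (shapes only).
import torch
from model.backbones.model_backbones_dit import DiT

model = DiT(dim = 512, depth = 4, heads = 8, dim_head = 64,
            mel_dim = 100, text_num_embeds = 256, conv_layers = 4)

batch, seq_len, text_len = 2, 300, 120
x = torch.randn(batch, seq_len, 100)             # noised mel frames
cond = torch.randn(batch, seq_len, 100)          # masked reference mel
text = torch.randint(0, 256, (batch, text_len))  # character token ids
time = torch.rand(batch)                         # flow step in [0, 1]
mask = torch.ones(batch, seq_len, dtype=torch.bool)

out = model(x, cond, text, time, drop_audio_cond = False, drop_text = False, mask = mask)
print(out.shape)  # torch.Size([2, 300, 100]) -- predicted mel, same length as input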
model/backbones/model_backbones_mmdit.py ADDED
@@ -0,0 +1,136 @@
+ """
+ ein notation:
+ b - batch
+ n - sequence
+ nt - text sequence
+ nw - raw wave length
+ d - dimension
+ """
+
+ from __future__ import annotations
+
+ import torch
+ from torch import nn
+
+ from einops import repeat
+
+ from x_transformers.x_transformers import RotaryEmbedding
+
+ from model.modules import (
+     TimestepEmbedding,
+     ConvPositionEmbedding,
+     MMDiTBlock,
+     AdaLayerNormZero_Final,
+     precompute_freqs_cis, get_pos_embed_indices,
+ )
+
+
+ # text embedding
+
+ class TextEmbedding(nn.Module):
+     def __init__(self, out_dim, text_num_embeds):
+         super().__init__()
+         self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token
+
+         self.precompute_max_pos = 1024
+         self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
+
+     def forward(self, text: int['b nt'], drop_text = False) -> float['b nt d']:
+         text = text + 1
+         if drop_text:
+             text = torch.zeros_like(text)
+         text = self.text_embed(text)
+
+         # sinusoidal pos emb
+         batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
+         batch_text_len = text.shape[1]
+         pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
+         text_pos_embed = self.freqs_cis[pos_idx]
+
+         text = text + text_pos_embed
+
+         return text
+
+
+ # noised input & masked cond audio embedding
+
+ class AudioEmbedding(nn.Module):
+     def __init__(self, in_dim, out_dim):
+         super().__init__()
+         self.linear = nn.Linear(2 * in_dim, out_dim)
+         self.conv_pos_embed = ConvPositionEmbedding(out_dim)
+
+     def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
+         if drop_audio_cond:
+             cond = torch.zeros_like(cond)
+         x = torch.cat((x, cond), dim = -1)
+         x = self.linear(x)
+         x = self.conv_pos_embed(x) + x
+         return x
+
+
+ # Transformer backbone using MM-DiT blocks
+
+ class MMDiT(nn.Module):
+     def __init__(self, *,
+                  dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
+                  text_num_embeds = 256, mel_dim = 100,
+     ):
+         super().__init__()
+
+         self.time_embed = TimestepEmbedding(dim)
+         self.text_embed = TextEmbedding(dim, text_num_embeds)
+         self.audio_embed = AudioEmbedding(mel_dim, dim)
+
+         self.rotary_embed = RotaryEmbedding(dim_head)
+
+         self.dim = dim
+         self.depth = depth
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 MMDiTBlock(
+                     dim = dim,
+                     heads = heads,
+                     dim_head = dim_head,
+                     dropout = dropout,
+                     ff_mult = ff_mult,
+                     context_pre_only = i == depth - 1,  # last block: text stream is not updated further
+                 )
+                 for i in range(depth)
+             ]
+         )
+         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
+         self.proj_out = nn.Linear(dim, mel_dim)
+
+     def forward(
+         self,
+         x: float['b n d'],             # noised input audio
+         cond: float['b n d'],          # masked cond audio
+         text: int['b nt'],             # text
+         time: float['b'] | float[''],  # time step
+         drop_audio_cond,               # cfg for cond audio
+         drop_text,                     # cfg for text
+         mask: bool['b n'] | None = None,
+     ):
+         batch = x.shape[0]
+         if time.ndim == 0:
+             time = repeat(time, ' -> b', b = batch)
+
+         # t: conditioning (time), c: context (text), x: noised input audio + masked cond audio
+         t = self.time_embed(time)
+         c = self.text_embed(text, drop_text = drop_text)
+         x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
+
+         seq_len = x.shape[1]
+         text_len = text.shape[1]
+         rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
+         rope_text = self.rotary_embed.forward_from_seq_len(text_len)
+
+         for block in self.transformer_blocks:
+             c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
+
+         x = self.norm_out(x, t)
+         output = self.proj_out(x)
+
+         return output
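
Note: the `drop_audio_cond` / `drop_text` flags in all three backbones exist to support classifier-free guidance at sampling time. The sketch below shows one common way a sampler combines the conditioned and fully dropped passes; `cfg_forward`, `cfg_strength`, and this particular combination rule are illustrative assumptions, not taken from this commit.

# Illustrative classifier-free guidance step (assumed sampler-side helper).
# Two forward passes through the same backbone: one fully conditioned, one
# with both text and cond audio dropped; their difference is amplified.
def cfg_forward(model, x, cond, text, time, cfg_strength = 2.0, mask = None):
    pred = model(x, cond, text, time, drop_audio_cond = False, drop_text = False, mask = mask)
    null_pred = model(x, cond, text, time, drop_audio_cond = True, drop_text = True, mask = mask)
    return pred + (pred - null_pred) * cfg_strength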
model/backbones/model_backbones_unett.py ADDED
@@ -0,0 +1,201 @@
+ """
+ ein notation:
+ b - batch
+ n - sequence
+ nt - text sequence
+ nw - raw wave length
+ d - dimension
+ """
+
+ from __future__ import annotations
+ from typing import Literal
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+ from einops import repeat, pack, unpack
+
+ from x_transformers import RMSNorm
+ from x_transformers.x_transformers import RotaryEmbedding
+
+ from model.modules import (
+     TimestepEmbedding,
+     ConvNeXtV2Block,
+     ConvPositionEmbedding,
+     Attention,
+     AttnProcessor,
+     FeedForward,
+     precompute_freqs_cis, get_pos_embed_indices,
+ )
+
+
+ # Text embedding
+
+ class TextEmbedding(nn.Module):
+     def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
+         super().__init__()
+         self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
+
+         if conv_layers > 0:
+             self.extra_modeling = True
+             self.precompute_max_pos = 4096  # ~44s of 24kHz audio
+             self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
+             self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
+         else:
+             self.extra_modeling = False
+
+     def forward(self, text: int['b nt'], seq_len, drop_text = False):
+         batch, text_len = text.shape[0], text.shape[1]
+         text = text + 1  # use 0 as filler token; batch preprocessing pads with -1, see list_str_to_idx()
+         text = text[:, :seq_len]  # truncate if there are more character tokens than mel spec frames
+         text = F.pad(text, (0, seq_len - text_len), value = 0)
+
+         if drop_text:  # cfg for text
+             text = torch.zeros_like(text)
+
+         text = self.text_embed(text)  # b n -> b n d
+
+         # possible extra modeling
+         if self.extra_modeling:
+             # sinusoidal pos emb
+             batch_start = torch.zeros((batch,), dtype=torch.long)
+             pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
+             text_pos_embed = self.freqs_cis[pos_idx]
+             text = text + text_pos_embed
+
+             # convnextv2 blocks
+             text = self.text_blocks(text)
+
+         return text
+
+
+ # noised input audio and context mixing embedding
+
+ class InputEmbedding(nn.Module):
+     def __init__(self, mel_dim, text_dim, out_dim):
+         super().__init__()
+         self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
+         self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
+
+     def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
+         if drop_audio_cond:  # cfg for cond audio
+             cond = torch.zeros_like(cond)
+
+         x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
+         x = self.conv_pos_embed(x) + x
+         return x
+
+
+ # Flat UNet Transformer backbone
+
+ class UNetT(nn.Module):
+     def __init__(self, *,
+                  dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
+                  mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
+                  skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
+     ):
+         super().__init__()
+         assert depth % 2 == 0, "UNet-Transformer's depth should be even."
+
+         self.time_embed = TimestepEmbedding(dim)
+         if text_dim is None:
+             text_dim = mel_dim
+         self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
+         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
+
+         self.rotary_embed = RotaryEmbedding(dim_head)
+
+         # transformer layers & skip connections
+
+         self.dim = dim
+         self.skip_connect_type = skip_connect_type
+         needs_skip_proj = skip_connect_type == 'concat'
+
+         self.depth = depth
+         self.layers = nn.ModuleList([])
+
+         for idx in range(depth):
+             is_later_half = idx >= (depth // 2)
+
+             attn_norm = RMSNorm(dim)
+             attn = Attention(
+                 processor = AttnProcessor(),
+                 dim = dim,
+                 heads = heads,
+                 dim_head = dim_head,
+                 dropout = dropout,
+             )
+
+             ff_norm = RMSNorm(dim)
+             ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
+
+             skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
+
+             self.layers.append(nn.ModuleList([
+                 skip_proj,
+                 attn_norm,
+                 attn,
+                 ff_norm,
+                 ff,
+             ]))
+
+         self.norm_out = RMSNorm(dim)
+         self.proj_out = nn.Linear(dim, mel_dim)
+
+     def forward(
+         self,
+         x: float['b n d'],             # noised input audio
+         cond: float['b n d'],          # masked cond audio
+         text: int['b nt'],             # text
+         time: float['b'] | float[''],  # time step
+         drop_audio_cond,               # cfg for cond audio
+         drop_text,                     # cfg for text
+         mask: bool['b n'] | None = None,
+     ):
+         batch, seq_len = x.shape[0], x.shape[1]
+         if time.ndim == 0:
+             time = repeat(time, ' -> b', b = batch)
+
+         # t: conditioning time; text and masked cond audio are mixed into the noised input x
+         t = self.time_embed(time)
+         text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
+         x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
+
+         # prepend time t to input x, [b n d] -> [b n+1 d]
+         x, ps = pack((t, x), 'b * d')
+         if mask is not None:
+             mask = F.pad(mask, (1, 0), value=1)
+
+         rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
+
+         # flat unet transformer
+         skip_connect_type = self.skip_connect_type
+         skips = []
+         for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
+             layer = idx + 1
+
+             # skip connection logic
+             is_first_half = layer <= (self.depth // 2)
+             is_later_half = not is_first_half
+
+             if is_first_half:
+                 skips.append(x)
+
+             if is_later_half:
+                 skip = skips.pop()
+                 if skip_connect_type == 'concat':
+                     x = torch.cat((x, skip), dim = -1)
+                     x = maybe_skip_proj(x)
+                 elif skip_connect_type == 'add':
+                     x = x + skip
+
+             # attention and feedforward blocks
+             x = attn(attn_norm(x), rope = rope, mask = mask) + x
+             x = ff(ff_norm(x)) + x
+
+         assert len(skips) == 0
+
+         _, x = unpack(self.norm_out(x), ps, 'b * d')
+
+         return self.proj_out(x)
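
Note: the toy snippet below (not part of this commit) isolates the skip-connection bookkeeping used in UNetT.forward. The first depth/2 layers push their inputs onto a stack; the later half pops LIFO, so the deepest encoder activation is fused first and the very first layer's input last. Identity "blocks" stand in for the attention and feedforward updates.

# Toy reproduction of UNetT's 'concat' skip pattern (illustrative only).
import torch
from torch import nn

depth, dim = 4, 8
# first half needs no projection; later half fuses cat(x, skip) back to dim
skip_projs = [None] * (depth // 2) + [nn.Linear(dim * 2, dim, bias=False) for _ in range(depth // 2)]

x = torch.randn(2, 10, dim)
skips = []
for idx in range(depth):
    layer = idx + 1
    if layer <= depth // 2:      # first half: remember activations
        skips.append(x)
    else:                        # later half: pop and fuse, most recent first (LIFO)
        skip = skips.pop()
        x = skip_projs[idx](torch.cat((x, skip), dim=-1))
    # ... attention and feedforward residual updates would run here ...
assert len(skips) == 0 and x.shape == (2, 10, dim)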