Staticaliza committed on
Commit 1d1273d
1 Parent(s): fcf7ece

Delete model

model/__init__.py DELETED
@@ -1,7 +0,0 @@
- from model.cfm import CFM
-
- from model.backbones.unett import UNetT
- from model.backbones.dit import DiT
- from model.backbones.mmdit import MMDiT
-
- from model.trainer import Trainer
 
model/backbones/dit.py DELETED
@@ -1,158 +0,0 @@
- """
- ein notation:
- b - batch
- n - sequence
- nt - text sequence
- nw - raw wave length
- d - dimension
- """
-
- from __future__ import annotations
-
- import torch
- from torch import nn
- import torch.nn.functional as F
-
- from einops import repeat
-
- from x_transformers.x_transformers import RotaryEmbedding
-
- from model.modules import (
-     TimestepEmbedding,
-     ConvNeXtV2Block,
-     ConvPositionEmbedding,
-     DiTBlock,
-     AdaLayerNormZero_Final,
-     precompute_freqs_cis, get_pos_embed_indices,
- )
-
-
- # Text embedding
-
- class TextEmbedding(nn.Module):
-     def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
-         super().__init__()
-         self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
-
-         if conv_layers > 0:
-             self.extra_modeling = True
-             self.precompute_max_pos = 4096  # ~44s of 24khz audio
-             self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
-             self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
-         else:
-             self.extra_modeling = False
-
-     def forward(self, text: int['b nt'], seq_len, drop_text = False):
-         batch, text_len = text.shape[0], text.shape[1]
-         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
-         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
-         text = F.pad(text, (0, seq_len - text_len), value = 0)
-
-         if drop_text:  # cfg for text
-             text = torch.zeros_like(text)
-
-         text = self.text_embed(text)  # b n -> b n d
-
-         # possible extra modeling
-         if self.extra_modeling:
-             # sinus pos emb
-             batch_start = torch.zeros((batch,), dtype=torch.long)
-             pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
-             text_pos_embed = self.freqs_cis[pos_idx]
-             text = text + text_pos_embed
-
-             # convnextv2 blocks
-             text = self.text_blocks(text)
-
-         return text
-
-
- # noised input audio and context mixing embedding
-
- class InputEmbedding(nn.Module):
-     def __init__(self, mel_dim, text_dim, out_dim):
-         super().__init__()
-         self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
-         self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
-
-     def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
-         if drop_audio_cond:  # cfg for cond audio
-             cond = torch.zeros_like(cond)
-
-         x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
-         x = self.conv_pos_embed(x) + x
-         return x
-
-
- # Transformer backbone using DiT blocks
-
- class DiT(nn.Module):
-     def __init__(self, *,
-                  dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
-                  mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
-                  long_skip_connection = False,
-     ):
-         super().__init__()
-
-         self.time_embed = TimestepEmbedding(dim)
-         if text_dim is None:
-             text_dim = mel_dim
-         self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
-         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
-
-         self.rotary_embed = RotaryEmbedding(dim_head)
-
-         self.dim = dim
-         self.depth = depth
-
-         self.transformer_blocks = nn.ModuleList(
-             [
-                 DiTBlock(
-                     dim = dim,
-                     heads = heads,
-                     dim_head = dim_head,
-                     ff_mult = ff_mult,
-                     dropout = dropout
-                 )
-                 for _ in range(depth)
-             ]
-         )
-         self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None
-
-         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
-         self.proj_out = nn.Linear(dim, mel_dim)
-
-     def forward(
-         self,
-         x: float['b n d'],  # noised input audio
-         cond: float['b n d'],  # masked cond audio
-         text: int['b nt'],  # text
-         time: float['b'] | float[''],  # time step
-         drop_audio_cond,  # cfg for cond audio
-         drop_text,  # cfg for text
-         mask: bool['b n'] | None = None,
-     ):
-         batch, seq_len = x.shape[0], x.shape[1]
-         if time.ndim == 0:
-             time = repeat(time, ' -> b', b = batch)
-
-         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
-         t = self.time_embed(time)
-         text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
-         x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
-
-         rope = self.rotary_embed.forward_from_seq_len(seq_len)
-
-         if self.long_skip_connection is not None:
-             residual = x
-
-         for block in self.transformer_blocks:
-             x = block(x, t, mask = mask, rope = rope)
-
-         if self.long_skip_connection is not None:
-             x = self.long_skip_connection(torch.cat((x, residual), dim = -1))
-
-         x = self.norm_out(x, t)
-         output = self.proj_out(x)
-
-         return output
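
For reference, a minimal sketch of how the deleted DiT backbone is typically constructed and called. The hyperparameters below are assumptions in the style of an F5-TTS "Base" config, not values taken from this commit, and the snippet presumes the deleted model/ package (plus x_transformers) is available on the path:

    import torch
    from model.backbones.dit import DiT

    # assumed hyperparameters; not read from this repo's configs
    model = DiT(dim=1024, depth=22, heads=16, dim_head=64,
                mel_dim=100, text_num_embeds=256, text_dim=512, conv_layers=4)

    x = torch.randn(2, 400, 100)             # noised mel frames, 'b n d'
    cond = torch.randn(2, 400, 100)          # masked conditioning mel, 'b n d'
    text = torch.randint(0, 256, (2, 120))   # character token ids, 'b nt'
    time = torch.rand(2)                     # flow-matching time step, 'b'

    pred = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
    print(pred.shape)                        # torch.Size([2, 400, 100])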
 
model/backbones/mmdit.py DELETED
@@ -1,136 +0,0 @@
- """
- ein notation:
- b - batch
- n - sequence
- nt - text sequence
- nw - raw wave length
- d - dimension
- """
-
- from __future__ import annotations
-
- import torch
- from torch import nn
-
- from einops import repeat
-
- from x_transformers.x_transformers import RotaryEmbedding
-
- from model.modules import (
-     TimestepEmbedding,
-     ConvPositionEmbedding,
-     MMDiTBlock,
-     AdaLayerNormZero_Final,
-     precompute_freqs_cis, get_pos_embed_indices,
- )
-
-
- # text embedding
-
- class TextEmbedding(nn.Module):
-     def __init__(self, out_dim, text_num_embeds):
-         super().__init__()
-         self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token
-
-         self.precompute_max_pos = 1024
-         self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
-
-     def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
-         text = text + 1
-         if drop_text:
-             text = torch.zeros_like(text)
-         text = self.text_embed(text)
-
-         # sinus pos emb
-         batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
-         batch_text_len = text.shape[1]
-         pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
-         text_pos_embed = self.freqs_cis[pos_idx]
-
-         text = text + text_pos_embed
-
-         return text
-
-
- # noised input & masked cond audio embedding
-
- class AudioEmbedding(nn.Module):
-     def __init__(self, in_dim, out_dim):
-         super().__init__()
-         self.linear = nn.Linear(2 * in_dim, out_dim)
-         self.conv_pos_embed = ConvPositionEmbedding(out_dim)
-
-     def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
-         if drop_audio_cond:
-             cond = torch.zeros_like(cond)
-         x = torch.cat((x, cond), dim = -1)
-         x = self.linear(x)
-         x = self.conv_pos_embed(x) + x
-         return x
-
-
- # Transformer backbone using MM-DiT blocks
-
- class MMDiT(nn.Module):
-     def __init__(self, *,
-                  dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
-                  text_num_embeds = 256, mel_dim = 100,
-     ):
-         super().__init__()
-
-         self.time_embed = TimestepEmbedding(dim)
-         self.text_embed = TextEmbedding(dim, text_num_embeds)
-         self.audio_embed = AudioEmbedding(mel_dim, dim)
-
-         self.rotary_embed = RotaryEmbedding(dim_head)
-
-         self.dim = dim
-         self.depth = depth
-
-         self.transformer_blocks = nn.ModuleList(
-             [
-                 MMDiTBlock(
-                     dim = dim,
-                     heads = heads,
-                     dim_head = dim_head,
-                     dropout = dropout,
-                     ff_mult = ff_mult,
-                     context_pre_only = i == depth - 1,
-                 )
-                 for i in range(depth)
-             ]
-         )
-         self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
-         self.proj_out = nn.Linear(dim, mel_dim)
-
-     def forward(
-         self,
-         x: float['b n d'],  # noised input audio
-         cond: float['b n d'],  # masked cond audio
-         text: int['b nt'],  # text
-         time: float['b'] | float[''],  # time step
-         drop_audio_cond,  # cfg for cond audio
-         drop_text,  # cfg for text
-         mask: bool['b n'] | None = None,
-     ):
-         batch = x.shape[0]
-         if time.ndim == 0:
-             time = repeat(time, ' -> b', b = batch)
-
-         # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
-         t = self.time_embed(time)
-         c = self.text_embed(text, drop_text = drop_text)
-         x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)
-
-         seq_len = x.shape[1]
-         text_len = text.shape[1]
-         rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
-         rope_text = self.rotary_embed.forward_from_seq_len(text_len)
-
-         for block in self.transformer_blocks:
-             c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)
-
-         x = self.norm_out(x, t)
-         output = self.proj_out(x)
-
-         return output
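
Unlike DiT above, MMDiT keeps text and audio as two streams with separate rotary embeddings and fuses them through joint attention, with context_pre_only set on the last block so the text stream is not carried further. A minimal, hedged usage sketch; the sizes are assumptions, not taken from this commit:

    import torch
    from model.backbones.mmdit import MMDiT

    model = MMDiT(dim=512, depth=16, heads=8, dim_head=64,
                  text_num_embeds=256, mel_dim=100)   # assumed sizes
    x = torch.randn(2, 400, 100)
    cond = torch.randn(2, 400, 100)
    text = torch.randint(0, 256, (2, 120))
    time = torch.rand(2)
    out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)  # (2, 400, 100)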
 
model/backbones/unett.py DELETED
@@ -1,201 +0,0 @@
- """
- ein notation:
- b - batch
- n - sequence
- nt - text sequence
- nw - raw wave length
- d - dimension
- """
-
- from __future__ import annotations
- from typing import Literal
-
- import torch
- from torch import nn
- import torch.nn.functional as F
-
- from einops import repeat, pack, unpack
-
- from x_transformers import RMSNorm
- from x_transformers.x_transformers import RotaryEmbedding
-
- from model.modules import (
-     TimestepEmbedding,
-     ConvNeXtV2Block,
-     ConvPositionEmbedding,
-     Attention,
-     AttnProcessor,
-     FeedForward,
-     precompute_freqs_cis, get_pos_embed_indices,
- )
-
-
- # Text embedding
-
- class TextEmbedding(nn.Module):
-     def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
-         super().__init__()
-         self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token
-
-         if conv_layers > 0:
-             self.extra_modeling = True
-             self.precompute_max_pos = 4096  # ~44s of 24khz audio
-             self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
-             self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
-         else:
-             self.extra_modeling = False
-
-     def forward(self, text: int['b nt'], seq_len, drop_text = False):
-         batch, text_len = text.shape[0], text.shape[1]
-         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
-         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
-         text = F.pad(text, (0, seq_len - text_len), value = 0)
-
-         if drop_text:  # cfg for text
-             text = torch.zeros_like(text)
-
-         text = self.text_embed(text)  # b n -> b n d
-
-         # possible extra modeling
-         if self.extra_modeling:
-             # sinus pos emb
-             batch_start = torch.zeros((batch,), dtype=torch.long)
-             pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
-             text_pos_embed = self.freqs_cis[pos_idx]
-             text = text + text_pos_embed
-
-             # convnextv2 blocks
-             text = self.text_blocks(text)
-
-         return text
-
-
- # noised input audio and context mixing embedding
-
- class InputEmbedding(nn.Module):
-     def __init__(self, mel_dim, text_dim, out_dim):
-         super().__init__()
-         self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
-         self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)
-
-     def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
-         if drop_audio_cond:  # cfg for cond audio
-             cond = torch.zeros_like(cond)
-
-         x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
-         x = self.conv_pos_embed(x) + x
-         return x
-
-
- # Flat UNet Transformer backbone
-
- class UNetT(nn.Module):
-     def __init__(self, *,
-                  dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
-                  mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
-                  skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
-     ):
-         super().__init__()
-         assert depth % 2 == 0, "UNet-Transformer's depth should be even."
-
-         self.time_embed = TimestepEmbedding(dim)
-         if text_dim is None:
-             text_dim = mel_dim
-         self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
-         self.input_embed = InputEmbedding(mel_dim, text_dim, dim)
-
-         self.rotary_embed = RotaryEmbedding(dim_head)
-
-         # transformer layers & skip connections
-
-         self.dim = dim
-         self.skip_connect_type = skip_connect_type
-         needs_skip_proj = skip_connect_type == 'concat'
-
-         self.depth = depth
-         self.layers = nn.ModuleList([])
-
-         for idx in range(depth):
-             is_later_half = idx >= (depth // 2)
-
-             attn_norm = RMSNorm(dim)
-             attn = Attention(
-                 processor = AttnProcessor(),
-                 dim = dim,
-                 heads = heads,
-                 dim_head = dim_head,
-                 dropout = dropout,
-             )
-
-             ff_norm = RMSNorm(dim)
-             ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
-
-             skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None
-
-             self.layers.append(nn.ModuleList([
-                 skip_proj,
-                 attn_norm,
-                 attn,
-                 ff_norm,
-                 ff,
-             ]))
-
-         self.norm_out = RMSNorm(dim)
-         self.proj_out = nn.Linear(dim, mel_dim)
-
-     def forward(
-         self,
-         x: float['b n d'],  # noised input audio
-         cond: float['b n d'],  # masked cond audio
-         text: int['b nt'],  # text
-         time: float['b'] | float[''],  # time step
-         drop_audio_cond,  # cfg for cond audio
-         drop_text,  # cfg for text
-         mask: bool['b n'] | None = None,
-     ):
-         batch, seq_len = x.shape[0], x.shape[1]
-         if time.ndim == 0:
-             time = repeat(time, ' -> b', b = batch)
-
-         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
-         t = self.time_embed(time)
-         text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
-         x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)
-
-         # postfix time t to input x, [b n d] -> [b n+1 d]
-         x, ps = pack((t, x), 'b * d')
-         if mask is not None:
-             mask = F.pad(mask, (1, 0), value=1)
-
-         rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)
-
-         # flat unet transformer
-         skip_connect_type = self.skip_connect_type
-         skips = []
-         for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
-             layer = idx + 1
-
-             # skip connection logic
-             is_first_half = layer <= (self.depth // 2)
-             is_later_half = not is_first_half
-
-             if is_first_half:
-                 skips.append(x)
-
-             if is_later_half:
-                 skip = skips.pop()
-                 if skip_connect_type == 'concat':
-                     x = torch.cat((x, skip), dim = -1)
-                     x = maybe_skip_proj(x)
-                 elif skip_connect_type == 'add':
-                     x = x + skip
-
-             # attention and feedforward blocks
-             x = attn(attn_norm(x), rope = rope, mask = mask) + x
-             x = ff(ff_norm(x)) + x
-
-         assert len(skips) == 0
-
-         _, x = unpack(self.norm_out(x), ps, 'b * d')
-
-         return self.proj_out(x)
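
UNetT is a flat Transformer with U-Net style skips: the first half of the stack pushes its hidden states onto a stack, the second half pops and merges them ('concat' through a projection, or 'add'), and the time embedding is packed in as an extra leading token, which is why depth must be even. A minimal sketch with assumed hyperparameters (not read from this commit):

    import torch
    from model.backbones.unett import UNetT

    model = UNetT(dim=1024, depth=24, heads=16, dim_head=64,
                  mel_dim=100, text_num_embeds=256, text_dim=512, conv_layers=4,
                  skip_connect_type="concat")   # depth must be even
    x = torch.randn(2, 400, 100)
    cond = torch.randn(2, 400, 100)
    text = torch.randint(0, 256, (2, 120))
    time = torch.rand(2)
    out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)  # (2, 400, 100)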
 
model/cfm.py DELETED
@@ -1,279 +0,0 @@
- """
- ein notation:
- b - batch
- n - sequence
- nt - text sequence
- nw - raw wave length
- d - dimension
- """
-
- from __future__ import annotations
- from typing import Callable
- from random import random
-
- import torch
- from torch import nn
- import torch.nn.functional as F
- from torch.nn.utils.rnn import pad_sequence
-
- from torchdiffeq import odeint
-
- from einops import rearrange
-
- from model.modules import MelSpec
-
- from model.utils import (
-     default, exists,
-     list_str_to_idx, list_str_to_tensor,
-     lens_to_mask, mask_from_frac_lengths,
- )
-
-
- class CFM(nn.Module):
-     def __init__(
-         self,
-         transformer: nn.Module,
-         sigma = 0.,
-         odeint_kwargs: dict = dict(
-             # atol = 1e-5,
-             # rtol = 1e-5,
-             method = 'euler'  # 'midpoint'
-         ),
-         audio_drop_prob = 0.3,
-         cond_drop_prob = 0.2,
-         num_channels = None,
-         mel_spec_module: nn.Module | None = None,
-         mel_spec_kwargs: dict = dict(),
-         frac_lengths_mask: tuple[float, float] = (0.7, 1.),
-         vocab_char_map: dict[str: int] | None = None
-     ):
-         super().__init__()
-
-         self.frac_lengths_mask = frac_lengths_mask
-
-         # mel spec
-         self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
-         num_channels = default(num_channels, self.mel_spec.n_mel_channels)
-         self.num_channels = num_channels
-
-         # classifier-free guidance
-         self.audio_drop_prob = audio_drop_prob
-         self.cond_drop_prob = cond_drop_prob
-
-         # transformer
-         self.transformer = transformer
-         dim = transformer.dim
-         self.dim = dim
-
-         # conditional flow related
-         self.sigma = sigma
-
-         # sampling related
-         self.odeint_kwargs = odeint_kwargs
-
-         # vocab map for tokenization
-         self.vocab_char_map = vocab_char_map
-
-     @property
-     def device(self):
-         return next(self.parameters()).device
-
-     @torch.no_grad()
-     def sample(
-         self,
-         cond: float['b n d'] | float['b nw'],
-         text: int['b nt'] | list[str],
-         duration: int | int['b'],
-         *,
-         lens: int['b'] | None = None,
-         steps = 32,
-         cfg_strength = 1.,
-         sway_sampling_coef = None,
-         seed: int | None = None,
-         max_duration = 4096,
-         vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
-         no_ref_audio = False,
-         duplicate_test = False,
-         t_inter = 0.1,
-         edit_mask = None,
-     ):
-         self.eval()
-
-         # raw wave
-
-         if cond.ndim == 2:
-             cond = self.mel_spec(cond)
-             cond = rearrange(cond, 'b d n -> b n d')
-             assert cond.shape[-1] == self.num_channels
-
-         batch, cond_seq_len, device = *cond.shape[:2], cond.device
-         if not exists(lens):
-             lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
-
-         # text
-
-         if isinstance(text, list):
-             if exists(self.vocab_char_map):
-                 text = list_str_to_idx(text, self.vocab_char_map).to(device)
-             else:
-                 text = list_str_to_tensor(text).to(device)
-             assert text.shape[0] == batch
-
-         if exists(text):
-             text_lens = (text != -1).sum(dim = -1)
-             lens = torch.maximum(text_lens, lens)  # make sure lengths are at least those of the text characters
-
-         # duration
-
-         cond_mask = lens_to_mask(lens)
-         if edit_mask is not None:
-             cond_mask = cond_mask & edit_mask
-
-         if isinstance(duration, int):
-             duration = torch.full((batch,), duration, device = device, dtype = torch.long)
-
-         duration = torch.maximum(lens + 1, duration)  # just add one token so something is generated
-         duration = duration.clamp(max = max_duration)
-         max_duration = duration.amax()
-
-         # duplicate test corner for inner time step observation
-         if duplicate_test:
-             test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
-
-         cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
-         cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
-         cond_mask = rearrange(cond_mask, '... -> ... 1')
-         step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))  # allow direct control (cut cond audio) with lens passed in
-
-         if batch > 1:
-             mask = lens_to_mask(duration)
-         else:  # save memory and speed up, as single-sample inference currently needs no mask
-             mask = None
-
-         # test for no ref audio
-         if no_ref_audio:
-             cond = torch.zeros_like(cond)
-
-         # neural ode
-
-         def fn(t, x):
-             # at each step, conditioning is fixed
-             # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
-
-             # predict flow
-             pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
-             if cfg_strength < 1e-5:
-                 return pred
-
-             null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
-             return pred + (pred - null_pred) * cfg_strength
-
-         # noise input
-         # to make sure batch inference result is same with different batch size, and for sure single inference
-         # still some difference maybe due to convolutional layers
-         y0 = []
-         for dur in duration:
-             if exists(seed):
-                 torch.manual_seed(seed)
-             y0.append(torch.randn(dur, self.num_channels, device = self.device))
-         y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
-
-         t_start = 0
-
-         # duplicate test corner for inner time step observation
-         if duplicate_test:
-             t_start = t_inter
-             y0 = (1 - t_start) * y0 + t_start * test_cond
-             steps = int(steps * (1 - t_start))
-
-         t = torch.linspace(t_start, 1, steps, device = self.device)
-         if sway_sampling_coef is not None:
-             t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
-
-         trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
-
-         sampled = trajectory[-1]
-         out = sampled
-         out = torch.where(cond_mask, cond, out)
-
-         if exists(vocoder):
-             out = rearrange(out, 'b n d -> b d n')
-             out = vocoder(out)
-
-         return out, trajectory
-
-     def forward(
-         self,
-         inp: float['b n d'] | float['b nw'],  # mel or raw wave
-         text: int['b nt'] | list[str],
-         *,
-         lens: int['b'] | None = None,
-         noise_scheduler: str | None = None,
-     ):
-         # handle raw wave
-         if inp.ndim == 2:
-             inp = self.mel_spec(inp)
-             inp = rearrange(inp, 'b d n -> b n d')
-             assert inp.shape[-1] == self.num_channels
-
-         batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
-
-         # handle text as string
-         if isinstance(text, list):
-             if exists(self.vocab_char_map):
-                 text = list_str_to_idx(text, self.vocab_char_map).to(device)
-             else:
-                 text = list_str_to_tensor(text).to(device)
-             assert text.shape[0] == batch
-
-         # lens and mask
-         if not exists(lens):
-             lens = torch.full((batch,), seq_len, device = device)
-
-         mask = lens_to_mask(lens, length = seq_len)  # useless here, as collate_fn will pad to max length in batch
-
-         # get a random span to mask out for training conditionally
-         frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
-         rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
-
-         if exists(mask):
-             rand_span_mask &= mask
-
-         # mel is x1
-         x1 = inp
-
-         # x0 is gaussian noise
-         x0 = torch.randn_like(x1)
-
-         # time step
-         time = torch.rand((batch,), dtype = dtype, device = self.device)
-         # TODO. noise_scheduler
-
-         # sample xt (φ_t(x) in the paper)
-         t = rearrange(time, 'b -> b 1 1')
-         φ = (1 - t) * x0 + t * x1
-         flow = x1 - x0
-
-         # only predict what is within the random mask span for infilling
-         cond = torch.where(
-             rand_span_mask[..., None],
-             torch.zeros_like(x1), x1
-         )
-
-         # transformer and cfg training with a drop rate
-         drop_audio_cond = random() < self.audio_drop_prob  # p_drop in voicebox paper
-         if random() < self.cond_drop_prob:  # p_uncond in voicebox paper
-             drop_audio_cond = True
-             drop_text = True
-         else:
-             drop_text = False
-
-         # to rigorously mask out padding, record it in collate_fn in dataset.py and pass it in here
-         # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
-         pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
-
-         # flow matching loss
-         loss = F.mse_loss(pred, flow, reduction = 'none')
-         loss = loss[rand_span_mask]
-
-         return loss.mean(), cond, pred
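
CFM wraps any of the backbones above: forward computes the masked flow-matching loss for training, while sample integrates the learned flow with torchdiffeq's odeint, calling the transformer twice per step for classifier-free guidance. A minimal inference sketch; the backbone sizes, guidance strength, and sway coefficient below are assumptions, and a real run would load trained weights and pass a vocoder:

    import torch
    from model.cfm import CFM
    from model.backbones.dit import DiT

    transformer = DiT(dim=1024, depth=22, heads=16, dim_head=64,
                      mel_dim=100, text_num_embeds=256, text_dim=512, conv_layers=4)
    cfm = CFM(transformer=transformer, odeint_kwargs=dict(method="euler"))
    # in practice, load a trained checkpoint here (path omitted)

    ref_mel = torch.randn(1, 300, 100)    # reference mel, 'b n d' (a raw wave 'b nw' also works)
    out, trajectory = cfm.sample(
        cond=ref_mel,
        text=["some target text"],        # tokenized internally (vocab_char_map or utf-8 fallback)
        duration=600,                     # total frames: reference plus generated continuation
        steps=32,
        cfg_strength=2.0,
        sway_sampling_coef=-1.0,
    )
    print(out.shape)                      # torch.Size([1, 600, 100]); reference frames are copied back in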
 
model/dataset.py DELETED
@@ -1,257 +0,0 @@
- import json
- import random
- from tqdm import tqdm
-
- import torch
- import torch.nn.functional as F
- from torch.utils.data import Dataset, Sampler
- import torchaudio
- from datasets import load_dataset, load_from_disk
- from datasets import Dataset as Dataset_
-
- from einops import rearrange
-
- from model.modules import MelSpec
-
-
- class HFDataset(Dataset):
-     def __init__(
-         self,
-         hf_dataset: Dataset,
-         target_sample_rate = 24_000,
-         n_mel_channels = 100,
-         hop_length = 256,
-     ):
-         self.data = hf_dataset
-         self.target_sample_rate = target_sample_rate
-         self.hop_length = hop_length
-         self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
-
-     def get_frame_len(self, index):
-         row = self.data[index]
-         audio = row['audio']['array']
-         sample_rate = row['audio']['sampling_rate']
-         return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, index):
-         row = self.data[index]
-         audio = row['audio']['array']
-
-         # logger.info(f"Audio shape: {audio.shape}")
-
-         sample_rate = row['audio']['sampling_rate']
-         duration = audio.shape[-1] / sample_rate
-
-         if duration > 30 or duration < 0.3:
-             return self.__getitem__((index + 1) % len(self.data))
-
-         audio_tensor = torch.from_numpy(audio).float()
-
-         if sample_rate != self.target_sample_rate:
-             resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
-             audio_tensor = resampler(audio_tensor)
-
-         audio_tensor = rearrange(audio_tensor, 't -> 1 t')
-
-         mel_spec = self.mel_spectrogram(audio_tensor)
-
-         mel_spec = rearrange(mel_spec, '1 d t -> d t')
-
-         text = row['text']
-
-         return dict(
-             mel_spec = mel_spec,
-             text = text,
-         )
-
-
- class CustomDataset(Dataset):
-     def __init__(
-         self,
-         custom_dataset: Dataset,
-         durations = None,
-         target_sample_rate = 24_000,
-         hop_length = 256,
-         n_mel_channels = 100,
-         preprocessed_mel = False,
-     ):
-         self.data = custom_dataset
-         self.durations = durations
-         self.target_sample_rate = target_sample_rate
-         self.hop_length = hop_length
-         self.preprocessed_mel = preprocessed_mel
-         if not preprocessed_mel:
-             self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
-
-     def get_frame_len(self, index):
-         if self.durations is not None:  # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
-             return self.durations[index] * self.target_sample_rate / self.hop_length
-         return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, index):
-         row = self.data[index]
-         audio_path = row["audio_path"]
-         text = row["text"]
-         duration = row["duration"]
-
-         if self.preprocessed_mel:
-             mel_spec = torch.tensor(row["mel_spec"])
-
-         else:
-             audio, source_sample_rate = torchaudio.load(audio_path)
-
-             if duration > 30 or duration < 0.3:
-                 return self.__getitem__((index + 1) % len(self.data))
-
-             if source_sample_rate != self.target_sample_rate:
-                 resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
-                 audio = resampler(audio)
-
-             mel_spec = self.mel_spectrogram(audio)
-             mel_spec = rearrange(mel_spec, '1 d t -> d t')
-
-         return dict(
-             mel_spec = mel_spec,
-             text = text,
-         )
-
-
- # Dynamic Batch Sampler
-
- class DynamicBatchSampler(Sampler[list[int]]):
-     """ Extension of Sampler that will do the following:
-         1.  Change the batch size (essentially number of sequences)
-             in a batch to ensure that the total number of frames are less
-             than a certain threshold.
-         2.  Make sure the padding efficiency in the batch is high.
-     """
-
-     def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
-         self.sampler = sampler
-         self.frames_threshold = frames_threshold
-         self.max_samples = max_samples
-
-         indices, batches = [], []
-         data_source = self.sampler.data_source
-
-         for idx in tqdm(self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"):
-             indices.append((idx, data_source.get_frame_len(idx)))
-         indices.sort(key=lambda elem : elem[1])
-
-         batch = []
-         batch_frames = 0
-         for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
-             if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
-                 batch.append(idx)
-                 batch_frames += frame_len
-             else:
-                 if len(batch) > 0:
-                     batches.append(batch)
-                 if frame_len <= self.frames_threshold:
-                     batch = [idx]
-                     batch_frames = frame_len
-                 else:
-                     batch = []
-                     batch_frames = 0
-
-         if not drop_last and len(batch) > 0:
-             batches.append(batch)
-
-         del indices
-
-         # to get different batches between epochs, set a seed and log it in the checkpoint,
-         # because in multi-gpu training, although the per-gpu batches do not change between epochs, the assembled global minibatch does
-         # e.g. for epoch n, use (random_seed + n)
-         random.seed(random_seed)
-         random.shuffle(batches)
-
-         self.batches = batches
-
-     def __iter__(self):
-         return iter(self.batches)
-
-     def __len__(self):
-         return len(self.batches)
-
-
- # Load dataset
-
- def load_dataset(
-     dataset_name: str,
-     tokenizer: str = "pinyin",
-     dataset_type: str = "CustomDataset",
-     audio_type: str = "raw",
-     mel_spec_kwargs: dict = dict()
- ) -> CustomDataset | HFDataset:
-     '''
-     dataset_type    - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
-                     - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
-     '''
-
-     print("Loading dataset ...")
-
-     if dataset_type == "CustomDataset":
-         if audio_type == "raw":
-             try:
-                 train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
-             except:
-                 train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
-             preprocessed_mel = False
-         elif audio_type == "mel":
-             train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
-             preprocessed_mel = True
-         with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
-             data_dict = json.load(f)
-         durations = data_dict["duration"]
-         train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
-
-     elif dataset_type == "CustomDatasetPath":
-         try:
-             train_dataset = load_from_disk(f"{dataset_name}/raw")
-         except:
-             train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
-
-         with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
-             data_dict = json.load(f)
-         durations = data_dict["duration"]
-         train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
-
-     elif dataset_type == "HFDataset":
-         print("Should manually modify the path of the huggingface dataset to your need.\n" +
-               "You may also need to adapt the corresponding script, since different datasets may have different formats.")
-         pre, post = dataset_name.split("_")
-         train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
-
-     return train_dataset
-
-
- # collation
-
- def collate_fn(batch):
-     mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
-     mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
-     max_mel_length = mel_lengths.amax()
-
-     padded_mel_specs = []
-     for spec in mel_specs:  # TODO. maybe record mask for attention here
-         padding = (0, max_mel_length - spec.size(-1))
-         padded_spec = F.pad(spec, padding, value = 0)
-         padded_mel_specs.append(padded_spec)
-
-     mel_specs = torch.stack(padded_mel_specs)
-
-     text = [item['text'] for item in batch]
-     text_lengths = torch.LongTensor([len(item) for item in text])
-
-     return dict(
-         mel = mel_specs,
-         mel_lengths = mel_lengths,
-         text = text,
-         text_lengths = text_lengths,
-     )
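
A minimal sketch of how these pieces are typically wired into a training loader; the dataset name and frame budget below are placeholders, and load_dataset expects a matching data/<name>_<tokenizer>/ directory to exist:

    from torch.utils.data import DataLoader, SequentialSampler
    from model.dataset import load_dataset, collate_fn, DynamicBatchSampler

    train_dataset = load_dataset("MyCorpus", tokenizer="pinyin")          # placeholder name
    sampler = SequentialSampler(train_dataset)
    batch_sampler = DynamicBatchSampler(sampler, frames_threshold=38400,  # placeholder frame budget
                                        max_samples=64, random_seed=666)

    train_loader = DataLoader(train_dataset,
                              batch_sampler=batch_sampler,
                              collate_fn=collate_fn,
                              num_workers=4)
    batch = next(iter(train_loader))
    # batch["mel"]: (b, n_mel_channels, n_frames) padded mels; batch["text"]: list of raw strings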
 
model/ecapa_tdnn.py DELETED
@@ -1,268 +0,0 @@
- # just for speaker similarity evaluation, third-party code
-
- # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
- # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
-
- import os
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
-
-
- ''' Res2Conv1d + BatchNorm1d + ReLU
- '''
-
- class Res2Conv1dReluBn(nn.Module):
-     '''
-     in_channels == out_channels == channels
-     '''
-
-     def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
-         super().__init__()
-         assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
-         self.scale = scale
-         self.width = channels // scale
-         self.nums = scale if scale == 1 else scale - 1
-
-         self.convs = []
-         self.bns = []
-         for i in range(self.nums):
-             self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
-             self.bns.append(nn.BatchNorm1d(self.width))
-         self.convs = nn.ModuleList(self.convs)
-         self.bns = nn.ModuleList(self.bns)
-
-     def forward(self, x):
-         out = []
-         spx = torch.split(x, self.width, 1)
-         for i in range(self.nums):
-             if i == 0:
-                 sp = spx[i]
-             else:
-                 sp = sp + spx[i]
-             # Order: conv -> relu -> bn
-             sp = self.convs[i](sp)
-             sp = self.bns[i](F.relu(sp))
-             out.append(sp)
-         if self.scale != 1:
-             out.append(spx[self.nums])
-         out = torch.cat(out, dim=1)
-
-         return out
-
-
- ''' Conv1d + BatchNorm1d + ReLU
- '''
-
- class Conv1dReluBn(nn.Module):
-     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
-         super().__init__()
-         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
-         self.bn = nn.BatchNorm1d(out_channels)
-
-     def forward(self, x):
-         return self.bn(F.relu(self.conv(x)))
-
-
- ''' The SE connection of 1D case.
- '''
-
- class SE_Connect(nn.Module):
-     def __init__(self, channels, se_bottleneck_dim=128):
-         super().__init__()
-         self.linear1 = nn.Linear(channels, se_bottleneck_dim)
-         self.linear2 = nn.Linear(se_bottleneck_dim, channels)
-
-     def forward(self, x):
-         out = x.mean(dim=2)
-         out = F.relu(self.linear1(out))
-         out = torch.sigmoid(self.linear2(out))
-         out = x * out.unsqueeze(2)
-
-         return out
-
-
- ''' SE-Res2Block of the ECAPA-TDNN architecture.
- '''
-
- # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
- #     return nn.Sequential(
- #         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
- #         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
- #         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
- #         SE_Connect(channels)
- #     )
-
- class SE_Res2Block(nn.Module):
-     def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
-         super().__init__()
-         self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
-         self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
-         self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
-         self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
-
-         self.shortcut = None
-         if in_channels != out_channels:
-             self.shortcut = nn.Conv1d(
-                 in_channels=in_channels,
-                 out_channels=out_channels,
-                 kernel_size=1,
-             )
-
-     def forward(self, x):
-         residual = x
-         if self.shortcut:
-             residual = self.shortcut(x)
-
-         x = self.Conv1dReluBn1(x)
-         x = self.Res2Conv1dReluBn(x)
-         x = self.Conv1dReluBn2(x)
-         x = self.SE_Connect(x)
-
-         return x + residual
-
-
- ''' Attentive weighted mean and standard deviation pooling.
- '''
-
- class AttentiveStatsPool(nn.Module):
-     def __init__(self, in_dim, attention_channels=128, global_context_att=False):
-         super().__init__()
-         self.global_context_att = global_context_att
-
-         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
-         if global_context_att:
-             self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
-         else:
-             self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
-         self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper
-
-     def forward(self, x):
-
-         if self.global_context_att:
-             context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
-             context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
-             x_in = torch.cat((x, context_mean, context_std), dim=1)
-         else:
-             x_in = x
-
-         # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
-         alpha = torch.tanh(self.linear1(x_in))
-         # alpha = F.relu(self.linear1(x_in))
-         alpha = torch.softmax(self.linear2(alpha), dim=2)
-         mean = torch.sum(alpha * x, dim=2)
-         residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
-         std = torch.sqrt(residuals.clamp(min=1e-9))
-         return torch.cat([mean, std], dim=1)
-
-
- class ECAPA_TDNN(nn.Module):
-     def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
-                  feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
-         super().__init__()
-
-         self.feat_type = feat_type
-         self.feature_selection = feature_selection
-         self.update_extract = update_extract
-         self.sr = sr
-
-         torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
-         try:
-             local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
-             self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
-         except:
-             self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
-
-         if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
-             self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
-         if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
-             self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
-
-         self.feat_num = self.get_feat_num()
-         self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
-
-         if feat_type != 'fbank' and feat_type != 'mfcc':
-             freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
-             for name, param in self.feature_extract.named_parameters():
-                 for freeze_val in freeze_list:
-                     if freeze_val in name:
-                         param.requires_grad = False
-                         break
-
-         if not self.update_extract:
-             for param in self.feature_extract.parameters():
-                 param.requires_grad = False
-
-         self.instance_norm = nn.InstanceNorm1d(feat_dim)
-         # self.channels = [channels] * 4 + [channels * 3]
-         self.channels = [channels] * 4 + [1536]
-
-         self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
-         self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
-         self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
-         self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
-
-         # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
-         cat_channels = channels * 3
-         self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
-         self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
-         self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
-         self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
-
-     def get_feat_num(self):
-         self.feature_extract.eval()
-         wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
-         with torch.no_grad():
-             features = self.feature_extract(wav)
-         select_feature = features[self.feature_selection]
-         if isinstance(select_feature, (list, tuple)):
-             return len(select_feature)
-         else:
-             return 1
-
-     def get_feat(self, x):
-         if self.update_extract:
-             x = self.feature_extract([sample for sample in x])
-         else:
-             with torch.no_grad():
-                 if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
-                     x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
-                 else:
-                     x = self.feature_extract([sample for sample in x])
-
-         if self.feat_type == 'fbank':
-             x = x.log()
-
-         if self.feat_type != "fbank" and self.feat_type != "mfcc":
-             x = x[self.feature_selection]
-             if isinstance(x, (list, tuple)):
-                 x = torch.stack(x, dim=0)
-             else:
-                 x = x.unsqueeze(0)
-             norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
-             x = (norm_weights * x).sum(dim=0)
-             x = torch.transpose(x, 1, 2) + 1e-6
-
-         x = self.instance_norm(x)
-         return x
-
-     def forward(self, x):
-         x = self.get_feat(x)
-
-         out1 = self.layer1(x)
-         out2 = self.layer2(out1)
-         out3 = self.layer3(out2)
-         out4 = self.layer4(out3)
-
-         out = torch.cat([out2, out3, out4], dim=1)
-         out = F.relu(self.conv(out))
-         out = self.bn(self.pooling(out))
-         out = self.linear(out)
-
-         return out
-
-
- def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
-     return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
-                       feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
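
This encoder is used only for speaker-similarity evaluation. A hedged sketch of scoring two utterances; it assumes s3prl is installed, the WavLM-Large extractor can be fetched via torch.hub, and that fine-tuned verification weights would be loaded separately:

    import torch
    import torch.nn.functional as F
    from model.ecapa_tdnn import ECAPA_TDNN_SMALL

    spk_encoder = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type="wavlm_large")
    spk_encoder.eval()
    # in practice, load the speaker-verification checkpoint here (path omitted)

    wav_gen = torch.randn(1, 16000)   # 1 s of 16 kHz audio, stand-in for a generated sample
    wav_ref = torch.randn(1, 16000)   # stand-in for the reference recording
    with torch.no_grad():
        emb_gen = spk_encoder(wav_gen)            # (1, 256)
        emb_ref = spk_encoder(wav_ref)            # (1, 256)
    sim = F.cosine_similarity(emb_gen, emb_ref)   # speaker similarity in [-1, 1]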
 
model/modules.py DELETED
@@ -1,574 +0,0 @@
1
- """
2
- ein notation:
3
- b - batch
4
- n - sequence
5
- nt - text sequence
6
- nw - raw wave length
7
- d - dimension
8
- """
9
-
10
- from __future__ import annotations
11
- from typing import Optional
12
- import math
13
-
14
- import torch
15
- from torch import nn
16
- import torch.nn.functional as F
17
- import torchaudio
18
-
19
- from einops import rearrange
20
- from x_transformers.x_transformers import apply_rotary_pos_emb
21
-
22
-
23
- # raw wav to mel spec
24
-
25
- class MelSpec(nn.Module):
26
- def __init__(
27
- self,
28
- filter_length = 1024,
29
- hop_length = 256,
30
- win_length = 1024,
31
- n_mel_channels = 100,
32
- target_sample_rate = 24_000,
33
- normalize = False,
34
- power = 1,
35
- norm = None,
36
- center = True,
37
- ):
38
- super().__init__()
39
- self.n_mel_channels = n_mel_channels
40
-
41
- self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
- sample_rate = target_sample_rate,
43
- n_fft = filter_length,
44
- win_length = win_length,
45
- hop_length = hop_length,
46
- n_mels = n_mel_channels,
47
- power = power,
48
- center = center,
49
- normalized = normalize,
50
- norm = norm,
51
- )
52
-
53
- self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
-
55
- def forward(self, inp):
56
- if len(inp.shape) == 3:
57
- inp = rearrange(inp, 'b 1 nw -> b nw')
58
-
59
- assert len(inp.shape) == 2
60
-
61
- if self.dummy.device != inp.device:
62
- self.to(inp.device)
63
-
64
- mel = self.mel_stft(inp)
65
- mel = mel.clamp(min = 1e-5).log()
66
- return mel
67
-
68
-
69
- # sinusoidal position embedding
70
-
71
- class SinusPositionEmbedding(nn.Module):
72
- def __init__(self, dim):
73
- super().__init__()
74
- self.dim = dim
75
-
76
- def forward(self, x, scale=1000):
77
- device = x.device
78
- half_dim = self.dim // 2
79
- emb = math.log(10000) / (half_dim - 1)
80
- emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
81
- emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
82
- emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
83
- return emb
84
-
85
-
86
- # convolutional position embedding
87
-
88
- class ConvPositionEmbedding(nn.Module):
89
- def __init__(self, dim, kernel_size = 31, groups = 16):
90
- super().__init__()
91
- assert kernel_size % 2 != 0
92
- self.conv1d = nn.Sequential(
93
- nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
- nn.Mish(),
95
- nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
- nn.Mish(),
97
- )
98
-
99
- def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
- if mask is not None:
101
- mask = mask[..., None]
102
- x = x.masked_fill(~mask, 0.)
103
-
104
- x = rearrange(x, 'b n d -> b d n')
105
- x = self.conv1d(x)
106
- out = rearrange(x, 'b d n -> b n d')
107
-
108
- if mask is not None:
109
- out = out.masked_fill(~mask, 0.)
110
-
111
- return out
112
-
113
-
114
- # rotary positional embedding related
115
-
116
- def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
117
- # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
- # has some connection to NTK literature
119
- # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
120
- # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
121
- theta *= theta_rescale_factor ** (dim / (dim - 2))
122
- freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
123
- t = torch.arange(end, device=freqs.device) # type: ignore
124
- freqs = torch.outer(t, freqs).float() # type: ignore
125
- freqs_cos = torch.cos(freqs) # real part
126
- freqs_sin = torch.sin(freqs) # imaginary part
127
- return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
-
129
- def get_pos_embed_indices(start, length, max_pos, scale=1.):
130
- # length = length if isinstance(length, int) else length.max()
131
- scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
- pos = start.unsqueeze(1) + (
133
- torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
- scale.unsqueeze(1)).long()
135
- # avoid extra long error.
136
- pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
- return pos
138
-
139
-
140
- # Global Response Normalization layer (Instance Normalization ?)
141
-
142
- class GRN(nn.Module):
143
- def __init__(self, dim):
144
- super().__init__()
145
- self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
146
- self.beta = nn.Parameter(torch.zeros(1, 1, dim))
147
-
148
- def forward(self, x):
149
- Gx = torch.norm(x, p=2, dim=1, keepdim=True)
150
- Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
151
- return self.gamma * (x * Nx) + self.beta + x
152
-
153
-
154
- # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
- # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
-
157
- class ConvNeXtV2Block(nn.Module):
158
- def __init__(
159
- self,
160
- dim: int,
161
- intermediate_dim: int,
162
- dilation: int = 1,
163
- ):
164
- super().__init__()
165
- padding = (dilation * (7 - 1)) // 2
166
- self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
167
- self.norm = nn.LayerNorm(dim, eps=1e-6)
168
- self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
- self.act = nn.GELU()
170
- self.grn = GRN(intermediate_dim)
171
- self.pwconv2 = nn.Linear(intermediate_dim, dim)
172
-
173
- def forward(self, x: torch.Tensor) -> torch.Tensor:
174
- residual = x
175
- x = x.transpose(1, 2) # b n d -> b d n
176
- x = self.dwconv(x)
177
- x = x.transpose(1, 2) # b d n -> b n d
178
- x = self.norm(x)
179
- x = self.pwconv1(x)
180
- x = self.act(x)
181
- x = self.grn(x)
182
- x = self.pwconv2(x)
183
- return residual + x
184
-
185
-
186
- # AdaLayerNormZero
187
- # return with modulated x for attn input, and params for later mlp modulation
188
-
189
- class AdaLayerNormZero(nn.Module):
190
- def __init__(self, dim):
191
- super().__init__()
192
-
193
- self.silu = nn.SiLU()
194
- self.linear = nn.Linear(dim, dim * 6)
195
-
196
- self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
-
198
- def forward(self, x, emb = None):
199
- emb = self.linear(self.silu(emb))
200
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
-
202
- x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
203
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
204
-
205
-
206
- # AdaLayerNormZero for final layer
207
- # return only with modulated x for attn input, cuz no more mlp modulation
208
-
209
- class AdaLayerNormZero_Final(nn.Module):
210
- def __init__(self, dim):
211
- super().__init__()
212
-
213
- self.silu = nn.SiLU()
214
- self.linear = nn.Linear(dim, dim * 2)
215
-
216
- self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
217
-
218
- def forward(self, x, emb):
219
- emb = self.linear(self.silu(emb))
220
- scale, shift = torch.chunk(emb, 2, dim=1)
221
-
222
- x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
223
- return x
224
-
225
-
226
- # FeedForward
227
-
228
- class FeedForward(nn.Module):
229
- def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
- super().__init__()
231
- inner_dim = int(dim * mult)
232
- dim_out = dim_out if dim_out is not None else dim
233
-
234
- activation = nn.GELU(approximate=approximate)
235
- project_in = nn.Sequential(
236
- nn.Linear(dim, inner_dim),
237
- activation
238
- )
239
- self.ff = nn.Sequential(
240
- project_in,
241
- nn.Dropout(dropout),
242
- nn.Linear(inner_dim, dim_out)
243
- )
244
-
245
- def forward(self, x):
246
- return self.ff(x)
247
-
248
-
249
- # Attention with possible joint part
250
- # modified from diffusers/src/diffusers/models/attention_processor.py
251
-
252
- class Attention(nn.Module):
253
- def __init__(
254
- self,
255
- processor: JointAttnProcessor | AttnProcessor,
256
- dim: int,
257
- heads: int = 8,
258
- dim_head: int = 64,
259
- dropout: float = 0.0,
260
- context_dim: Optional[int] = None, # if not None -> joint attention
261
- context_pre_only = None,
262
- ):
263
- super().__init__()
264
-
265
- if not hasattr(F, "scaled_dot_product_attention"):
266
- raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
267
-
268
- self.processor = processor
269
-
270
- self.dim = dim
271
- self.heads = heads
272
- self.inner_dim = dim_head * heads
273
- self.dropout = dropout
274
-
275
- self.context_dim = context_dim
276
- self.context_pre_only = context_pre_only
277
-
278
- self.to_q = nn.Linear(dim, self.inner_dim)
279
- self.to_k = nn.Linear(dim, self.inner_dim)
280
- self.to_v = nn.Linear(dim, self.inner_dim)
281
-
282
- if self.context_dim is not None:
283
- self.to_k_c = nn.Linear(context_dim, self.inner_dim)
284
- self.to_v_c = nn.Linear(context_dim, self.inner_dim)
285
- if self.context_pre_only is not None:
286
- self.to_q_c = nn.Linear(context_dim, self.inner_dim)
287
-
288
- self.to_out = nn.ModuleList([])
289
- self.to_out.append(nn.Linear(self.inner_dim, dim))
290
- self.to_out.append(nn.Dropout(dropout))
291
-
292
- if self.context_pre_only is not None and not self.context_pre_only:
293
- self.to_out_c = nn.Linear(self.inner_dim, dim)
294
-
295
- def forward(
296
- self,
297
- x: float['b n d'], # noised input x
298
- c: float['b n d'] = None, # context c
299
- mask: bool['b n'] | None = None,
300
- rope = None, # rotary position embedding for x
301
- c_rope = None, # rotary position embedding for c
302
- ) -> torch.Tensor:
303
- if c is not None:
304
- return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
- else:
306
- return self.processor(self, x, mask = mask, rope = rope)
307
-
308
-
309
- # Attention processor
310
-
311
- class AttnProcessor:
312
- def __init__(self):
313
- pass
314
-
315
- def __call__(
316
- self,
317
- attn: Attention,
318
- x: float['b n d'], # noised input x
319
- mask: bool['b n'] | None = None,
320
- rope = None, # rotary position embedding
321
- ) -> torch.FloatTensor:
322
-
323
- batch_size = x.shape[0]
324
-
325
- # `sample` projections.
326
- query = attn.to_q(x)
327
- key = attn.to_k(x)
328
- value = attn.to_v(x)
329
-
330
- # apply rotary position embedding
331
- if rope is not None:
332
- freqs, xpos_scale = rope
333
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
-
335
- query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
- key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
337
-
338
- # attention
339
- inner_dim = key.shape[-1]
340
- head_dim = inner_dim // attn.heads
341
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
342
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
343
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
-
345
- # mask, e.g. when an inference batch has different target durations; mask out the padding
346
- if mask is not None:
347
- attn_mask = mask
348
- attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
- else:
351
- attn_mask = None
352
-
353
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
354
- x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
355
- x = x.to(query.dtype)
356
-
357
- # linear proj
358
- x = attn.to_out[0](x)
359
- # dropout
360
- x = attn.to_out[1](x)
361
-
362
- if mask is not None:
363
- mask = rearrange(mask, 'b n -> b n 1')
364
- x = x.masked_fill(~mask, 0.)
365
-
366
- return x
367
-
368
-
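
A short usage sketch for the self-attention path above, assuming Attention and AttnProcessor are importable from model.modules as elsewhere in this repo; the boolean mask marks valid frames per sample, and padded frames are zeroed in the output.

    import torch
    from model.modules import Attention, AttnProcessor   # definitions shown above

    attn = Attention(processor=AttnProcessor(), dim=64, heads=8, dim_head=8)

    x = torch.randn(2, 10, 64)                         # (batch, seq, dim)
    lens = torch.tensor([10, 7])                       # valid lengths per sample
    mask = torch.arange(10)[None, :] < lens[:, None]   # (batch, seq) boolean mask

    out = attn(x, mask=mask)                           # rope omitted here for brevity
    print(out.shape)   # torch.Size([2, 10, 64])
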
369
- # Joint Attention processor for MM-DiT
370
- # modified from diffusers/src/diffusers/models/attention_processor.py
371
-
372
- class JointAttnProcessor:
373
- def __init__(self):
374
- pass
375
-
376
- def __call__(
377
- self,
378
- attn: Attention,
379
- x: float['b n d'], # noised input x
380
- c: float['b nt d'] = None, # context c, here text
381
- mask: bool['b n'] | None = None,
382
- rope = None, # rotary position embedding for x
383
- c_rope = None, # rotary position embedding for c
384
- ) -> torch.FloatTensor:
385
- residual = x
386
-
387
- batch_size = c.shape[0]
388
-
389
- # `sample` projections.
390
- query = attn.to_q(x)
391
- key = attn.to_k(x)
392
- value = attn.to_v(x)
393
-
394
- # `context` projections.
395
- c_query = attn.to_q_c(c)
396
- c_key = attn.to_k_c(c)
397
- c_value = attn.to_v_c(c)
398
-
399
- # apply rope for context and noised input independently
400
- if rope is not None:
401
- freqs, xpos_scale = rope
402
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
- query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
- key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
- if c_rope is not None:
406
- freqs, xpos_scale = c_rope
407
- q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
- c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
- c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
-
411
- # attention
412
- query = torch.cat([query, c_query], dim=1)
413
- key = torch.cat([key, c_key], dim=1)
414
- value = torch.cat([value, c_value], dim=1)
415
-
416
- inner_dim = key.shape[-1]
417
- head_dim = inner_dim // attn.heads
418
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
419
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
420
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
421
-
422
- # mask, e.g. when an inference batch has different target durations; mask out the padding
423
- if mask is not None:
424
- attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
- attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
- attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
- else:
428
- attn_mask = None
429
-
430
- x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
431
- x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
432
- x = x.to(query.dtype)
433
-
434
- # Split the attention outputs.
435
- x, c = (
436
- x[:, :residual.shape[1]],
437
- x[:, residual.shape[1]:],
438
- )
439
-
440
- # linear proj
441
- x = attn.to_out[0](x)
442
- # dropout
443
- x = attn.to_out[1](x)
444
- if not attn.context_pre_only:
445
- c = attn.to_out_c(c)
446
-
447
- if mask is not None:
448
- mask = rearrange(mask, 'b n -> b n 1')
449
- x = x.masked_fill(~mask, 0.)
450
- # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
-
452
- return x, c
453
-
454
-
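
Note how the joint processor above works: the noised-input and context sequences are concatenated along the sequence axis, attended over jointly, and the result is split back at the original input length (only the input part is padding-masked; the text part is attended without a mask). The split itself is just:

    import torch

    n_input, n_text = 10, 4
    joint = torch.randn(2, n_input + n_text, 64)            # joint attention output (illustrative)
    x_out, c_out = joint[:, :n_input], joint[:, n_input:]   # back to input and context branches
    print(x_out.shape, c_out.shape)   # torch.Size([2, 10, 64]) torch.Size([2, 4, 64])
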
455
- # DiT Block
456
-
457
- class DiTBlock(nn.Module):
458
-
459
- def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
- super().__init__()
461
-
462
- self.attn_norm = AdaLayerNormZero(dim)
463
- self.attn = Attention(
464
- processor = AttnProcessor(),
465
- dim = dim,
466
- heads = heads,
467
- dim_head = dim_head,
468
- dropout = dropout,
469
- )
470
-
471
- self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
- self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
-
474
- def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
- # pre-norm & modulation for attention input
476
- norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
-
478
- # attention
479
- attn_output = self.attn(x=norm, mask=mask, rope=rope)
480
-
481
- # process attention output for input x
482
- x = x + gate_msa.unsqueeze(1) * attn_output
483
-
484
- norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
- ff_output = self.ff(norm)
486
- x = x + gate_mlp.unsqueeze(1) * ff_output
487
-
488
- return x
489
-
490
-
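
A minimal sketch of driving the block above with a time embedding, assuming DiTBlock and TimestepEmbedding (defined later in this file) are importable from model.modules; dimensions are illustrative and the rotary embedding is omitted.

    import torch
    from model.modules import DiTBlock, TimestepEmbedding

    dim = 64
    block = DiTBlock(dim=dim, heads=8, dim_head=8)
    time_embed = TimestepEmbedding(dim)

    x = torch.randn(2, 10, dim)      # (batch, seq, dim) noised input
    t = time_embed(torch.rand(2))    # (batch,) flow time -> (batch, dim)

    x = block(x, t)                  # adaLN-Zero modulated attention + feed-forward
    print(x.shape)   # torch.Size([2, 10, 64])
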
491
- # MMDiT Block https://arxiv.org/abs/2403.03206
492
-
493
- class MMDiTBlock(nn.Module):
494
- r"""
495
- modified from diffusers/src/diffusers/models/attention.py
496
- notes:
497
- _c: context-related (text, cond, etc.; left branch in SD3 Fig. 2b)
498
- _x: noised-input-related (right branch)
499
- context_pre_only: the last layer only applies pre-norm + modulation, since it has no further FFN
500
- """
501
-
502
- def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
503
- super().__init__()
504
-
505
- self.context_pre_only = context_pre_only
506
-
507
- self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
508
- self.attn_norm_x = AdaLayerNormZero(dim)
509
- self.attn = Attention(
510
- processor = JointAttnProcessor(),
511
- dim = dim,
512
- heads = heads,
513
- dim_head = dim_head,
514
- dropout = dropout,
515
- context_dim = dim,
516
- context_pre_only = context_pre_only,
517
- )
518
-
519
- if not context_pre_only:
520
- self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
521
- self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
522
- else:
523
- self.ff_norm_c = None
524
- self.ff_c = None
525
- self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
526
- self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
527
-
528
- def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
529
- # pre-norm & modulation for attention input
530
- if self.context_pre_only:
531
- norm_c = self.attn_norm_c(c, t)
532
- else:
533
- norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
534
- norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
535
-
536
- # attention
537
- x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
538
-
539
- # process attention output for context c
540
- if self.context_pre_only:
541
- c = None
542
- else: # if not last layer
543
- c = c + c_gate_msa.unsqueeze(1) * c_attn_output
544
-
545
- norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
546
- c_ff_output = self.ff_c(norm_c)
547
- c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
548
-
549
- # process attention output for input x
550
- x = x + x_gate_msa.unsqueeze(1) * x_attn_output
551
-
552
- norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
553
- x_ff_output = self.ff_x(norm_x)
554
- x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
555
-
556
- return c, x
557
-
558
-
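
The block above returns the updated (c, x) pair, and with context_pre_only=True (used for the final layer) the context branch is dropped and c comes back as None. A stack is therefore typically driven as below (illustrative, assuming the classes are importable from model.modules):

    import torch
    from model.modules import MMDiTBlock, TimestepEmbedding

    dim, depth = 64, 4
    blocks = torch.nn.ModuleList(
        MMDiTBlock(dim=dim, heads=8, dim_head=8, context_pre_only=(i == depth - 1))
        for i in range(depth)
    )
    t = TimestepEmbedding(dim)(torch.rand(2))
    x = torch.randn(2, 10, dim)   # noised-input branch
    c = torch.randn(2, 4, dim)    # context (text) branch

    for block in blocks:
        c, x = block(x, c, t)
    print(x.shape, c)   # torch.Size([2, 10, 64]) None after the final, context_pre_only block
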
559
- # time step conditioning embedding
560
-
561
- class TimestepEmbedding(nn.Module):
562
- def __init__(self, dim, freq_embed_dim=256):
563
- super().__init__()
564
- self.time_embed = SinusPositionEmbedding(freq_embed_dim)
565
- self.time_mlp = nn.Sequential(
566
- nn.Linear(freq_embed_dim, dim),
567
- nn.SiLU(),
568
- nn.Linear(dim, dim)
569
- )
570
-
571
- def forward(self, timestep: float['b']):
572
- time_hidden = self.time_embed(timestep)
573
- time = self.time_mlp(time_hidden) # b d
574
- return time
 
 
model/trainer.py DELETED
@@ -1,250 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- import gc
5
- from tqdm import tqdm
6
- import wandb
7
-
8
- import torch
9
- from torch.optim import AdamW
10
- from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
- from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
-
13
- from einops import rearrange
14
-
15
- from accelerate import Accelerator
16
- from accelerate.utils import DistributedDataParallelKwargs
17
-
18
- from ema_pytorch import EMA
19
-
20
- from model import CFM
21
- from model.utils import exists, default
22
- from model.dataset import DynamicBatchSampler, collate_fn
23
-
24
-
25
- # trainer
26
-
27
- class Trainer:
28
- def __init__(
29
- self,
30
- model: CFM,
31
- epochs,
32
- learning_rate,
33
- num_warmup_updates = 20000,
34
- save_per_updates = 1000,
35
- checkpoint_path = None,
36
- batch_size = 32,
37
- batch_size_type: str = "sample",
38
- max_samples = 32,
39
- grad_accumulation_steps = 1,
40
- max_grad_norm = 1.0,
41
- noise_scheduler: str | None = None,
42
- duration_predictor: torch.nn.Module | None = None,
43
- wandb_project = "test_e2-tts",
44
- wandb_run_name = "test_run",
45
- wandb_resume_id: str = None,
46
- last_per_steps = None,
47
- accelerate_kwargs: dict = dict(),
48
- ema_kwargs: dict = dict()
49
- ):
50
-
51
- ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
52
-
53
- self.accelerator = Accelerator(
54
- log_with = "wandb",
55
- kwargs_handlers = [ddp_kwargs],
56
- gradient_accumulation_steps = grad_accumulation_steps,
57
- **accelerate_kwargs
58
- )
59
-
60
- if exists(wandb_resume_id):
61
- init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
- else:
63
- init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
- self.accelerator.init_trackers(
65
- project_name = wandb_project,
66
- init_kwargs=init_kwargs,
67
- config={"epochs": epochs,
68
- "learning_rate": learning_rate,
69
- "num_warmup_updates": num_warmup_updates,
70
- "batch_size": batch_size,
71
- "batch_size_type": batch_size_type,
72
- "max_samples": max_samples,
73
- "grad_accumulation_steps": grad_accumulation_steps,
74
- "max_grad_norm": max_grad_norm,
75
- "gpus": self.accelerator.num_processes,
76
- "noise_scheduler": noise_scheduler}
77
- )
78
-
79
- self.model = model
80
-
81
- if self.is_main:
82
- self.ema_model = EMA(
83
- model,
84
- include_online_model = False,
85
- **ema_kwargs
86
- )
87
-
88
- self.ema_model.to(self.accelerator.device)
89
-
90
- self.epochs = epochs
91
- self.num_warmup_updates = num_warmup_updates
92
- self.save_per_updates = save_per_updates
93
- self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
- self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
-
96
- self.batch_size = batch_size
97
- self.batch_size_type = batch_size_type
98
- self.max_samples = max_samples
99
- self.grad_accumulation_steps = grad_accumulation_steps
100
- self.max_grad_norm = max_grad_norm
101
-
102
- self.noise_scheduler = noise_scheduler
103
-
104
- self.duration_predictor = duration_predictor
105
-
106
- self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
- self.model, self.optimizer = self.accelerator.prepare(
108
- self.model, self.optimizer
109
- )
110
-
111
- @property
112
- def is_main(self):
113
- return self.accelerator.is_main_process
114
-
115
- def save_checkpoint(self, step, last=False):
116
- self.accelerator.wait_for_everyone()
117
- if self.is_main:
118
- checkpoint = dict(
119
- model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
- optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
- ema_model_state_dict = self.ema_model.state_dict(),
122
- scheduler_state_dict = self.scheduler.state_dict(),
123
- step = step
124
- )
125
- if not os.path.exists(self.checkpoint_path):
126
- os.makedirs(self.checkpoint_path)
127
- if last:
128
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
- print(f"Saved last checkpoint at step {step}")
130
- else:
131
- self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
-
133
- def load_checkpoint(self):
134
- if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
135
- return 0
136
-
137
- self.accelerator.wait_for_everyone()
138
- if "model_last.pt" in os.listdir(self.checkpoint_path):
139
- latest_checkpoint = "model_last.pt"
140
- else:
141
- latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
- # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
- checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
144
-
145
- if self.is_main:
146
- self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
147
-
148
- if 'step' in checkpoint:
149
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
150
- self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
151
- if self.scheduler:
152
- self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
153
- step = checkpoint['step']
154
- else:
155
- checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
156
- self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
157
- step = 0
158
-
159
- del checkpoint; gc.collect()
160
- return step
161
-
162
- def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
163
-
164
- if exists(resumable_with_seed):
165
- generator = torch.Generator()
166
- generator.manual_seed(resumable_with_seed)
167
- else:
168
- generator = None
169
-
170
- if self.batch_size_type == "sample":
171
- train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
172
- batch_size=self.batch_size, shuffle=True, generator=generator)
173
- elif self.batch_size_type == "frame":
174
- self.accelerator.even_batches = False
175
- sampler = SequentialSampler(train_dataset)
176
- batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
177
- train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
178
- batch_sampler=batch_sampler)
179
- else:
180
- raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
181
-
182
- # accelerator.prepare() dispatches batches across devices,
183
- # so the dataloader length computed above must account for the number of devices
184
- warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
185
- # otherwise by default with split_batches=False, warmup steps change with num_processes
186
- total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
187
- decay_steps = total_steps - warmup_steps
188
- warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
189
- decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
190
- self.scheduler = SequentialLR(self.optimizer,
191
- schedulers=[warmup_scheduler, decay_scheduler],
192
- milestones=[warmup_steps])
193
- train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
194
- start_step = self.load_checkpoint()
195
- global_step = start_step
196
-
197
- if exists(resumable_with_seed):
198
- orig_epoch_step = len(train_dataloader)
199
- skipped_epoch = int(start_step // orig_epoch_step)
200
- skipped_batch = start_step % orig_epoch_step
201
- skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
202
- else:
203
- skipped_epoch = 0
204
-
205
- for epoch in range(skipped_epoch, self.epochs):
206
- self.model.train()
207
- if exists(resumable_with_seed) and epoch == skipped_epoch:
208
- progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
209
- initial=skipped_batch, total=orig_epoch_step)
210
- else:
211
- progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
212
-
213
- for batch in progress_bar:
214
- with self.accelerator.accumulate(self.model):
215
- text_inputs = batch['text']
216
- mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
217
- mel_lengths = batch["mel_lengths"]
218
-
219
- # TODO. add duration predictor training
220
- if self.duration_predictor is not None and self.accelerator.is_local_main_process:
221
- dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
222
- self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
223
-
224
- loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
225
- self.accelerator.backward(loss)
226
-
227
- if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
228
- self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
229
-
230
- self.optimizer.step()
231
- self.scheduler.step()
232
- self.optimizer.zero_grad()
233
-
234
- if self.is_main:
235
- self.ema_model.update()
236
-
237
- global_step += 1
238
-
239
- if self.accelerator.is_local_main_process:
240
- self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
241
-
242
- progress_bar.set_postfix(step=str(global_step), loss=loss.item())
243
-
244
- if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
245
- self.save_checkpoint(global_step)
246
-
247
- if global_step % self.last_per_steps == 0:
248
- self.save_checkpoint(global_step, last=True)
249
-
250
- self.accelerator.end_training()
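
For reference, the learning-rate schedule assembled in train() above is a linear warmup from (almost) zero up to the base learning rate, followed by a linear decay back towards zero, stitched together with SequentialLR. A self-contained sketch of just that composition, with illustrative step counts and learning rate:

    import torch
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import LinearLR, SequentialLR

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = AdamW(params, lr=7.5e-5)            # base learning rate is illustrative

    warmup_steps, decay_steps = 100, 900
    warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
    decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
    scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

    for _ in range(warmup_steps + decay_steps):
        optimizer.step()
        scheduler.step()
    # lr ramps to ~7.5e-5 over the first 100 steps, then decays towards zero by step 1000
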
 
 
model/utils.py DELETED
@@ -1,580 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import os
4
- import re
5
- import math
6
- import random
7
- import string
8
- from tqdm import tqdm
9
- from collections import defaultdict
10
-
11
- import matplotlib
12
- matplotlib.use("Agg")
13
- import matplotlib.pylab as plt
14
-
15
- import torch
16
- import torch.nn.functional as F
17
- from torch.nn.utils.rnn import pad_sequence
18
- import torchaudio
19
-
20
- import einx
21
- from einops import rearrange, reduce
22
-
23
- import jieba
24
- from pypinyin import lazy_pinyin, Style
25
-
26
- from model.ecapa_tdnn import ECAPA_TDNN_SMALL
27
- from model.modules import MelSpec
28
-
29
-
30
- # seed everything
31
-
32
- def seed_everything(seed = 0):
33
- random.seed(seed)
34
- os.environ['PYTHONHASHSEED'] = str(seed)
35
- torch.manual_seed(seed)
36
- torch.cuda.manual_seed(seed)
37
- torch.cuda.manual_seed_all(seed)
38
- torch.backends.cudnn.deterministic = True
39
- torch.backends.cudnn.benchmark = False
40
-
41
- # helpers
42
-
43
- def exists(v):
44
- return v is not None
45
-
46
- def default(v, d):
47
- return v if exists(v) else d
48
-
49
- # tensor helpers
50
-
51
- def lens_to_mask(
52
- t: int['b'],
53
- length: int | None = None
54
- ) -> bool['b n']:
55
-
56
- if not exists(length):
57
- length = t.amax()
58
-
59
- seq = torch.arange(length, device = t.device)
60
- return einx.less('n, b -> b n', seq, t)
61
-
62
- def mask_from_start_end_indices(
63
- seq_len: int['b'],
64
- start: int['b'],
65
- end: int['b']
66
- ):
67
- max_seq_len = seq_len.max().item()
68
- seq = torch.arange(max_seq_len, device = start.device).long()
69
- return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
70
-
71
- def mask_from_frac_lengths(
72
- seq_len: int['b'],
73
- frac_lengths: float['b']
74
- ):
75
- lengths = (frac_lengths * seq_len).long()
76
- max_start = seq_len - lengths
77
-
78
- rand = torch.rand_like(frac_lengths)
79
- start = (max_start * rand).long().clamp(min = 0)
80
- end = start + lengths
81
-
82
- return mask_from_start_end_indices(seq_len, start, end)
83
-
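
A toy example of the two helpers above, assuming this file is importable as model.utils; lens_to_mask is deterministic, while mask_from_frac_lengths samples a random contiguous span per sequence (used for infilling-style training), so only its shape is shown.

    import torch
    from model.utils import lens_to_mask, mask_from_frac_lengths

    lens = torch.tensor([2, 4])
    print(lens_to_mask(lens, length=5))
    # tensor([[ True,  True, False, False, False],
    #         [ True,  True,  True,  True, False]])

    span = mask_from_frac_lengths(lens, torch.tensor([0.5, 0.5]))   # ~50% of each sequence
    print(span.shape)   # torch.Size([2, 4]) -- width is max(lens)
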
84
- def maybe_masked_mean(
85
- t: float['b n d'],
86
- mask: bool['b n'] = None
87
- ) -> float['b d']:
88
-
89
- if not exists(mask):
90
- return t.mean(dim = 1)
91
-
92
- t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
93
- num = reduce(t, 'b n d -> b d', 'sum')
94
- den = reduce(mask.float(), 'b n -> b', 'sum')
95
-
96
- return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
97
-
98
-
99
- # simple utf-8 tokenizer, since paper went character based
100
- def list_str_to_tensor(
101
- text: list[str],
102
- padding_value = -1
103
- ) -> int['b nt']:
104
- list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
105
- text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
106
- return text
107
-
108
- # char tokenizer, based on custom dataset's extracted .txt file
109
- def list_str_to_idx(
110
- text: list[str] | list[list[str]],
111
- vocab_char_map: dict[str, int], # {char: idx}
112
- padding_value = -1
113
- ) -> int['b nt']:
114
- list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
115
- text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
116
- return text
117
-
118
-
119
- # Get tokenizer
120
-
121
- def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
122
- '''
123
- tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
124
- - "char" for char-wise tokenizer, need .txt vocab_file
125
- - "byte" for utf-8 tokenizer
126
- - "custom" if you're directly passing in a path to the vocab.txt you want to use
127
- vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
128
- - if use "char", derived from unfiltered character & symbol counts of custom dataset
129
- - if use "byte", set to 256 (unicode byte range)
130
- '''
131
- if tokenizer in ["pinyin", "char"]:
132
- with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
133
- vocab_char_map = {}
134
- for i, char in enumerate(f):
135
- vocab_char_map[char[:-1]] = i
136
- vocab_size = len(vocab_char_map)
137
- assert vocab_char_map[" "] == 0, "make sure space is at index 0 in vocab.txt, since 0 is used for the unknown char"
138
-
139
- elif tokenizer == "byte":
140
- vocab_char_map = None
141
- vocab_size = 256
142
- elif tokenizer == "custom":
143
- with open (dataset_name, "r", encoding="utf-8") as f:
144
- vocab_char_map = {}
145
- for i, char in enumerate(f):
146
- vocab_char_map[char[:-1]] = i
147
- vocab_size = len(vocab_char_map)
148
-
149
- return vocab_char_map, vocab_size
150
-
151
-
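
A short sketch of how the vocab map above is consumed downstream, using list_str_to_idx defined earlier in this file; the dataset name is hypothetical and resolves to data/my_dataset_char/vocab.txt, whose first line must be a space (index 0 doubles as the unknown-character fallback).

    from model.utils import get_tokenizer, list_str_to_idx

    # "my_dataset" is a placeholder; the vocab is read from data/my_dataset_char/vocab.txt
    vocab_char_map, vocab_size = get_tokenizer("my_dataset", tokenizer="char")

    batch = ["hello world", "hi"]
    text_ids = list_str_to_idx(batch, vocab_char_map)   # (batch, max_len), padded with -1
    print(text_ids.shape, vocab_size)
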
152
- # convert char to pinyin
153
-
154
- def convert_char_to_pinyin(text_list, polyphone = True):
155
- final_text_list = []
156
- god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
157
- custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
158
- for text in text_list:
159
- char_list = []
160
- text = text.translate(god_knows_why_en_testset_contains_zh_quote)
161
- text = text.translate(custom_trans)
162
- for seg in jieba.cut(text):
163
- seg_byte_len = len(bytes(seg, 'UTF-8'))
164
- if seg_byte_len == len(seg): # if pure alphabets and symbols
165
- if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
166
- char_list.append(" ")
167
- char_list.extend(seg)
168
- elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
169
- seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
170
- for c in seg:
171
- if c not in "。,、;:?!《》【】—…":
172
- char_list.append(" ")
173
- char_list.append(c)
174
- else: # if mixed chinese characters, alphabets and symbols
175
- for c in seg:
176
- if ord(c) < 256:
177
- char_list.extend(c)
178
- else:
179
- if c not in "。,、;:?!《》【】—…":
180
- char_list.append(" ")
181
- char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
182
- else: # if is zh punc
183
- char_list.append(c)
184
- final_text_list.append(char_list)
185
-
186
- return final_text_list
187
-
188
-
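
A quick illustration of the mixed-language conversion above: jieba segments the text, pure-ASCII segments are kept character by character, and Chinese segments become TONE3 pinyin tokens, each preceded by a space; the exact token list depends on jieba's segmentation.

    from model.utils import convert_char_to_pinyin

    out = convert_char_to_pinyin(["Hello 世界"])
    # one token list per input string; the Chinese part turns into tone-numbered
    # pinyin such as "shi4" and "jie4", while "Hello" stays as single characters
    print(out[0])
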
189
- # save spectrogram
190
- def save_spectrogram(spectrogram, path):
191
- plt.figure(figsize=(12, 4))
192
- plt.imshow(spectrogram, origin='lower', aspect='auto')
193
- plt.colorbar()
194
- plt.savefig(path)
195
- plt.close()
196
-
197
-
198
- # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
199
- def get_seedtts_testset_metainfo(metalst):
200
- f = open(metalst); lines = f.readlines(); f.close()
201
- metainfo = []
202
- for line in lines:
203
- if len(line.strip().split('|')) == 5:
204
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
205
- elif len(line.strip().split('|')) == 4:
206
- utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
207
- gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
208
- if not os.path.isabs(prompt_wav):
209
- prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
210
- metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
211
- return metainfo
212
-
213
-
214
- # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
215
- def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
216
- f = open(metalst); lines = f.readlines(); f.close()
217
- metainfo = []
218
- for line in lines:
219
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
220
-
221
- # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
222
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
223
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
224
-
225
- # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
226
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
227
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
228
-
229
- metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
230
-
231
- return metainfo
232
-
233
-
234
- # padded to max length mel batch
235
- def padded_mel_batch(ref_mels):
236
- max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
237
- padded_ref_mels = []
238
- for mel in ref_mels:
239
- padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
240
- padded_ref_mels.append(padded_ref_mel)
241
- padded_ref_mels = torch.stack(padded_ref_mels)
242
- padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
243
- return padded_ref_mels
244
-
245
-
246
- # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
247
-
248
- def get_inference_prompt(
249
- metainfo,
250
- speed = 1., tokenizer = "pinyin", polyphone = True,
251
- target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
252
- use_truth_duration = False,
253
- infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
254
- ):
255
- prompts_all = []
256
-
257
- min_tokens = min_secs * target_sample_rate // hop_length
258
- max_tokens = max_secs * target_sample_rate // hop_length
259
-
260
- batch_accum = [0] * num_buckets
261
- utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
262
- ([[] for _ in range(num_buckets)] for _ in range(6))
263
-
264
- mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
265
-
266
- for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
267
-
268
- # Audio
269
- ref_audio, ref_sr = torchaudio.load(prompt_wav)
270
- ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
271
- if ref_rms < target_rms:
272
- ref_audio = ref_audio * target_rms / ref_rms
273
- assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
274
- if ref_sr != target_sample_rate:
275
- resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
276
- ref_audio = resampler(ref_audio)
277
-
278
- # Text
279
- if len(prompt_text[-1].encode('utf-8')) == 1:
280
- prompt_text = prompt_text + " "
281
- text = [prompt_text + gt_text]
282
- if tokenizer == "pinyin":
283
- text_list = convert_char_to_pinyin(text, polyphone = polyphone)
284
- else:
285
- text_list = text
286
-
287
- # Duration, mel frame length
288
- ref_mel_len = ref_audio.shape[-1] // hop_length
289
- if use_truth_duration:
290
- gt_audio, gt_sr = torchaudio.load(gt_wav)
291
- if gt_sr != target_sample_rate:
292
- resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
293
- gt_audio = resampler(gt_audio)
294
- total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
295
-
296
- # # test vocoder resynthesis
297
- # ref_audio = gt_audio
298
- else:
299
- zh_pause_punc = r"。,、;:?!"
300
- ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
301
- gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
302
- total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
303
-
304
- # to mel spectrogram
305
- ref_mel = mel_spectrogram(ref_audio)
306
- ref_mel = rearrange(ref_mel, '1 d n -> d n')
307
-
308
- # deal with batch
309
- assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
310
- assert min_tokens <= total_mel_len <= max_tokens, \
311
- f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
312
- bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
313
-
314
- utts[bucket_i].append(utt)
315
- ref_rms_list[bucket_i].append(ref_rms)
316
- ref_mels[bucket_i].append(ref_mel)
317
- ref_mel_lens[bucket_i].append(ref_mel_len)
318
- total_mel_lens[bucket_i].append(total_mel_len)
319
- final_text_list[bucket_i].extend(text_list)
320
-
321
- batch_accum[bucket_i] += total_mel_len
322
-
323
- if batch_accum[bucket_i] >= infer_batch_size:
324
- # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
325
- prompts_all.append((
326
- utts[bucket_i],
327
- ref_rms_list[bucket_i],
328
- padded_mel_batch(ref_mels[bucket_i]),
329
- ref_mel_lens[bucket_i],
330
- total_mel_lens[bucket_i],
331
- final_text_list[bucket_i]
332
- ))
333
- batch_accum[bucket_i] = 0
334
- utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
335
-
336
- # add residual
337
- for bucket_i, bucket_frames in enumerate(batch_accum):
338
- if bucket_frames > 0:
339
- prompts_all.append((
340
- utts[bucket_i],
341
- ref_rms_list[bucket_i],
342
- padded_mel_batch(ref_mels[bucket_i]),
343
- ref_mel_lens[bucket_i],
344
- total_mel_lens[bucket_i],
345
- final_text_list[bucket_i]
346
- ))
347
- # shuffle so the last workers are not left with only the easy (short) buckets
348
- random.seed(666)
349
- random.shuffle(prompts_all)
350
-
351
- return prompts_all
352
-
353
-
354
- # get wav_res_ref_text of seed-tts test metalst
355
- # https://github.com/BytedanceSpeech/seed-tts-eval
356
-
357
- def get_seed_tts_test(metalst, gen_wav_dir, gpus):
358
- f = open(metalst)
359
- lines = f.readlines()
360
- f.close()
361
-
362
- test_set_ = []
363
- for line in tqdm(lines):
364
- if len(line.strip().split('|')) == 5:
365
- utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
366
- elif len(line.strip().split('|')) == 4:
367
- utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
368
-
369
- if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
370
- continue
371
- gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
372
- if not os.path.isabs(prompt_wav):
373
- prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
374
-
375
- test_set_.append((gen_wav, prompt_wav, gt_text))
376
-
377
- num_jobs = len(gpus)
378
- if num_jobs == 1:
379
- return [(gpus[0], test_set_)]
380
-
381
- wav_per_job = len(test_set_) // num_jobs + 1
382
- test_set = []
383
- for i in range(num_jobs):
384
- test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
385
-
386
- return test_set
387
-
388
-
389
- # get librispeech test-clean cross sentence test
390
-
391
- def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
392
- f = open(metalst)
393
- lines = f.readlines()
394
- f.close()
395
-
396
- test_set_ = []
397
- for line in tqdm(lines):
398
- ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
399
-
400
- if eval_ground_truth:
401
- gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
402
- gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
403
- else:
404
- if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
405
- raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
406
- gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
407
-
408
- ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
409
- ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
410
-
411
- test_set_.append((gen_wav, ref_wav, gen_txt))
412
-
413
- num_jobs = len(gpus)
414
- if num_jobs == 1:
415
- return [(gpus[0], test_set_)]
416
-
417
- wav_per_job = len(test_set_) // num_jobs + 1
418
- test_set = []
419
- for i in range(num_jobs):
420
- test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
421
-
422
- return test_set
423
-
424
-
425
- # load asr model
426
-
427
- def load_asr_model(lang, ckpt_dir = ""):
428
- if lang == "zh":
429
- from funasr import AutoModel
430
- model = AutoModel(
431
- model = os.path.join(ckpt_dir, "paraformer-zh"),
432
- # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
433
- # punc_model = os.path.join(ckpt_dir, "ct-punc"),
434
- # spk_model = os.path.join(ckpt_dir, "cam++"),
435
- disable_update=True,
436
- ) # following seed-tts setting
437
- elif lang == "en":
438
- from faster_whisper import WhisperModel
439
- model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
440
- model = WhisperModel(model_size, device="cuda", compute_type="float16")
441
- return model
442
-
443
-
444
- # WER Evaluation, the way Seed-TTS does
445
-
446
- def run_asr_wer(args):
447
- rank, lang, test_set, ckpt_dir = args
448
-
449
- if lang == "zh":
450
- import zhconv
451
- torch.cuda.set_device(rank)
452
- elif lang == "en":
453
- os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
454
- else:
455
- raise NotImplementedError("lang supports only 'zh' (funasr paraformer-zh) and 'en' (faster-whisper-large-v3) for now.")
456
-
457
- asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
458
-
459
- from zhon.hanzi import punctuation
460
- punctuation_all = punctuation + string.punctuation
461
- wers = []
462
-
463
- from jiwer import compute_measures
464
- for gen_wav, prompt_wav, truth in tqdm(test_set):
465
- if lang == "zh":
466
- res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
467
- hypo = res[0]["text"]
468
- hypo = zhconv.convert(hypo, 'zh-cn')
469
- elif lang == "en":
470
- segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
471
- hypo = ''
472
- for segment in segments:
473
- hypo = hypo + ' ' + segment.text
474
-
475
- # raw_truth = truth
476
- # raw_hypo = hypo
477
-
478
- for x in punctuation_all:
479
- truth = truth.replace(x, '')
480
- hypo = hypo.replace(x, '')
481
-
482
- truth = truth.replace('  ', ' ')
483
- hypo = hypo.replace('  ', ' ')
484
-
485
- if lang == "zh":
486
- truth = " ".join([x for x in truth])
487
- hypo = " ".join([x for x in hypo])
488
- elif lang == "en":
489
- truth = truth.lower()
490
- hypo = hypo.lower()
491
-
492
- measures = compute_measures(truth, hypo)
493
- wer = measures["wer"]
494
-
495
- # ref_list = truth.split(" ")
496
- # subs = measures["substitutions"] / len(ref_list)
497
- # dele = measures["deletions"] / len(ref_list)
498
- # inse = measures["insertions"] / len(ref_list)
499
-
500
- wers.append(wer)
501
-
502
- return wers
503
-
504
-
505
- # SIM Evaluation
506
-
507
- def run_sim(args):
508
- rank, test_set, ckpt_dir = args
509
- device = f"cuda:{rank}"
510
-
511
- model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
512
- state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
513
- model.load_state_dict(state_dict['model'], strict=False)
514
-
515
- use_gpu = torch.cuda.is_available()
516
- if use_gpu:
517
- model = model.cuda(device)
518
- model.eval()
519
-
520
- sim_list = []
521
- for wav1, wav2, truth in tqdm(test_set):
522
-
523
- wav1, sr1 = torchaudio.load(wav1)
524
- wav2, sr2 = torchaudio.load(wav2)
525
-
526
- resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
527
- resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
528
- wav1 = resample1(wav1)
529
- wav2 = resample2(wav2)
530
-
531
- if use_gpu:
532
- wav1 = wav1.cuda(device)
533
- wav2 = wav2.cuda(device)
534
- with torch.no_grad():
535
- emb1 = model(wav1)
536
- emb2 = model(wav2)
537
-
538
- sim = F.cosine_similarity(emb1, emb2)[0].item()
539
- # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
540
- sim_list.append(sim)
541
-
542
- return sim_list
543
-
544
-
545
- # filter func for dirty data with many repetitions
546
-
547
- def repetition_found(text, length = 2, tolerance = 10):
548
- pattern_count = defaultdict(int)
549
- for i in range(len(text) - length + 1):
550
- pattern = text[i:i + length]
551
- pattern_count[pattern] += 1
552
- for pattern, count in pattern_count.items():
553
- if count > tolerance:
554
- return True
555
- return False
556
-
557
-
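
The filter above flags transcripts in which any short pattern repeats suspiciously often (by default, a 2-character pattern occurring more than 10 times), e.g.:

    from model.utils import repetition_found

    print(repetition_found("ha" * 12))   # True: the 2-gram "ha" occurs 12 times, above the tolerance of 10
    print(repetition_found("abcdefg"))   # False: no 2-gram repeats at all
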
558
- # load model checkpoint for inference
559
-
560
- def load_checkpoint(model, ckpt_path, device, use_ema = True):
561
- from ema_pytorch import EMA
562
-
563
- ckpt_type = ckpt_path.split(".")[-1]
564
- if ckpt_type == "safetensors":
565
- from safetensors.torch import load_file
566
- checkpoint = load_file(ckpt_path, device=device)
567
- else:
568
- checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
569
-
570
- if use_ema:
571
- ema_model = EMA(model, include_online_model = False).to(device)
572
- if ckpt_type == "safetensors":
573
- ema_model.load_state_dict(checkpoint)
574
- else:
575
- ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
576
- ema_model.copy_params_from_ema_to_model()
577
- else:
578
- model.load_state_dict(checkpoint['model_state_dict'])
579
-
580
- return model
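
Finally, a hedged sketch of restoring a model for inference with load_checkpoint above. The backbone configuration and checkpoint path are placeholders, and passing the DiT backbone as the first argument of CFM is an assumption about model/cfm.py (not shown in this commit); use the exact construction from the training or inference scripts.

    import torch
    from model import CFM, DiT             # exported by model/__init__.py
    from model.utils import load_checkpoint

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # assumption: CFM wraps the backbone passed as its first argument; config values are placeholders
    model = CFM(DiT(dim=1024, depth=22, heads=16, dim_head=64, mel_dim=100, text_num_embeds=256)).to(device)

    # hypothetical path; both .pt and .safetensors checkpoints are handled above
    model = load_checkpoint(model, "ckpts/test_e2-tts/model_last.pt", device, use_ema=True)
    model.eval()
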