Spaces: Running on Zero

Staticaliza committed "Delete model"
Commit · 1d1273d
Parent(s): fcf7ece

Files changed:
- model/__init__.py +0 -7
- model/backbones/dit.py +0 -158
- model/backbones/mmdit.py +0 -136
- model/backbones/unett.py +0 -201
- model/cfm.py +0 -279
- model/dataset.py +0 -257
- model/ecapa_tdnn.py +0 -268
- model/modules.py +0 -574
- model/trainer.py +0 -250
- model/utils.py +0 -580
model/__init__.py
DELETED
@@ -1,7 +0,0 @@

from model.cfm import CFM

from model.backbones.unett import UNetT
from model.backbones.dit import DiT
from model.backbones.mmdit import MMDiT

from model.trainer import Trainer
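For context, the deleted model/__init__.py only re-exported the package's public classes, so downstream code could import everything from the top-level package. A minimal sketch of what that enabled; it assumes the deleted package is still importable, and the hyperparameter values are illustrative placeholders, not this Space's actual configuration:

import torch
from model import CFM, DiT  # re-exported by the deleted model/__init__.py

backbone = DiT(dim=512, depth=8, heads=8, mel_dim=100, text_num_embeds=256)  # placeholder sizes
cfm = CFM(transformer=backbone)  # wraps the backbone with the conditional flow-matching objective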
model/backbones/dit.py
DELETED
@@ -1,158 +0,0 @@

"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn
import torch.nn.functional as F

from einops import repeat

from x_transformers.x_transformers import RotaryEmbedding

from model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    DiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis, get_pos_embed_indices,
)


# Text embedding

class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
        else:
            self.extra_modeling = False

    def forward(self, text: int['b nt'], seq_len, drop_text = False):
        batch, text_len = text.shape[0], text.shape[1]
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        text = F.pad(text, (0, seq_len - text_len), value = 0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding

class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)

    def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
        x = self.conv_pos_embed(x) + x
        return x


# Transformer backbone using DiT blocks

class DiT(nn.Module):
    def __init__(self, *,
                 dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
                 mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
                 long_skip_connection = False,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    ff_mult = ff_mult,
                    dropout = dropout
                )
                for _ in range(depth)
            ]
        )
        self.long_skip_connection = nn.Linear(dim * 2, dim, bias = False) if long_skip_connection else None

        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float['b n d'],  # nosied input audio
        cond: float['b n d'],  # masked cond audio
        text: int['b nt'],  # text
        time: float['b'] | float[''],  # time step
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool['b n'] | None = None,
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = repeat(time, ' -> b', b = batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
        x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)

        rope = self.rotary_embed.forward_from_seq_len(seq_len)

        if self.long_skip_connection is not None:
            residual = x

        for block in self.transformer_blocks:
            x = block(x, t, mask = mask, rope = rope)

        if self.long_skip_connection is not None:
            x = self.long_skip_connection(torch.cat((x, residual), dim = -1))

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output
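A quick shape sanity-check for the deleted DiT backbone above. This is a sketch under the assumption that the rest of the deleted model package (model.modules and its DiTBlock, TimestepEmbedding, etc.) is still importable; all sizes are placeholders:

import torch
from model.backbones.dit import DiT

model = DiT(dim=256, depth=4, heads=4, dim_head=64, mel_dim=100, text_num_embeds=256, conv_layers=4)
x    = torch.randn(2, 300, 100)           # noised mel frames (b, n, mel_dim)
cond = torch.randn(2, 300, 100)           # masked reference mel, zeroed where infilling happens
text = torch.randint(0, 256, (2, 120))    # character / pinyin token ids (b, nt)
time = torch.rand(2)                      # flow-matching time step per sample
out = model(x, cond, text, time, drop_audio_cond=False, drop_text=False)
print(out.shape)                          # expected: torch.Size([2, 300, 100]), one flow vector per mel frame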
model/backbones/mmdit.py
DELETED
@@ -1,136 +0,0 @@

"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations

import torch
from torch import nn

from einops import repeat

from x_transformers.x_transformers import RotaryEmbedding

from model.modules import (
    TimestepEmbedding,
    ConvPositionEmbedding,
    MMDiTBlock,
    AdaLayerNormZero_Final,
    precompute_freqs_cis, get_pos_embed_indices,
)


# text embedding

class TextEmbedding(nn.Module):
    def __init__(self, out_dim, text_num_embeds):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim)  # will use 0 as filler token

        self.precompute_max_pos = 1024
        self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)

    def forward(self, text: int['b nt'], drop_text = False) -> int['b nt d']:
        text = text + 1
        if drop_text:
            text = torch.zeros_like(text)
        text = self.text_embed(text)

        # sinus pos emb
        batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
        batch_text_len = text.shape[1]
        pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
        text_pos_embed = self.freqs_cis[pos_idx]

        text = text + text_pos_embed

        return text


# noised input & masked cond audio embedding

class AudioEmbedding(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(2 * in_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(out_dim)

    def forward(self, x: float['b n d'], cond: float['b n d'], drop_audio_cond = False):
        if drop_audio_cond:
            cond = torch.zeros_like(cond)
        x = torch.cat((x, cond), dim = -1)
        x = self.linear(x)
        x = self.conv_pos_embed(x) + x
        return x


# Transformer backbone using MM-DiT blocks

class MMDiT(nn.Module):
    def __init__(self, *,
                 dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
                 text_num_embeds = 256, mel_dim = 100,
    ):
        super().__init__()

        self.time_embed = TimestepEmbedding(dim)
        self.text_embed = TextEmbedding(dim, text_num_embeds)
        self.audio_embed = AudioEmbedding(mel_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        self.dim = dim
        self.depth = depth

        self.transformer_blocks = nn.ModuleList(
            [
                MMDiTBlock(
                    dim = dim,
                    heads = heads,
                    dim_head = dim_head,
                    dropout = dropout,
                    ff_mult = ff_mult,
                    context_pre_only = i == depth - 1,
                )
                for i in range(depth)
            ]
        )
        self.norm_out = AdaLayerNormZero_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float['b n d'],  # nosied input audio
        cond: float['b n d'],  # masked cond audio
        text: int['b nt'],  # text
        time: float['b'] | float[''],  # time step
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool['b n'] | None = None,
    ):
        batch = x.shape[0]
        if time.ndim == 0:
            time = repeat(time, ' -> b', b = batch)

        # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        c = self.text_embed(text, drop_text = drop_text)
        x = self.audio_embed(x, cond, drop_audio_cond = drop_audio_cond)

        seq_len = x.shape[1]
        text_len = text.shape[1]
        rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
        rope_text = self.rotary_embed.forward_from_seq_len(text_len)

        for block in self.transformer_blocks:
            c, x = block(x, c, t, mask = mask, rope = rope_audio, c_rope = rope_text)

        x = self.norm_out(x, t)
        output = self.proj_out(x)

        return output
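Unlike DiT above, the deleted MMDiT keeps text and audio as two streams: each MMDiTBlock returns an updated (c, x) pair, the text gets its own rotary embedding of length nt, and the last block runs with context_pre_only=True so only the audio stream feeds the output head. A hedged construction sketch with placeholder sizes, again assuming the rest of the deleted package is importable:

import torch
from model.backbones.mmdit import MMDiT

model = MMDiT(dim=256, depth=4, heads=4, mel_dim=100, text_num_embeds=256)
x    = torch.randn(2, 300, 100)
cond = torch.randn(2, 300, 100)
text = torch.randint(0, 256, (2, 120))   # note: text is not padded to the mel length here
out = model(x, cond, text, torch.rand(2), drop_audio_cond=False, drop_text=False)
print(out.shape)                         # expected: torch.Size([2, 300, 100])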
model/backbones/unett.py
DELETED
@@ -1,201 +0,0 @@

"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Literal

import torch
from torch import nn
import torch.nn.functional as F

from einops import repeat, pack, unpack

from x_transformers import RMSNorm
from x_transformers.x_transformers import RotaryEmbedding

from model.modules import (
    TimestepEmbedding,
    ConvNeXtV2Block,
    ConvPositionEmbedding,
    Attention,
    AttnProcessor,
    FeedForward,
    precompute_freqs_cis, get_pos_embed_indices,
)


# Text embedding

class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, conv_layers = 0, conv_mult = 2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
        else:
            self.extra_modeling = False

    def forward(self, text: int['b nt'], seq_len, drop_text = False):
        batch, text_len = text.shape[0], text.shape[1]
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        text = F.pad(text, (0, seq_len - text_len), value = 0)

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            text = self.text_blocks(text)

        return text


# noised input audio and context mixing embedding

class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim = out_dim)

    def forward(self, x: float['b n d'], cond: float['b n d'], text_embed: float['b n d'], drop_audio_cond = False):
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim = -1))
        x = self.conv_pos_embed(x) + x
        return x


# Flat UNet Transformer backbone

class UNetT(nn.Module):
    def __init__(self, *,
                 dim, depth = 8, heads = 8, dim_head = 64, dropout = 0.1, ff_mult = 4,
                 mel_dim = 100, text_num_embeds = 256, text_dim = None, conv_layers = 0,
                 skip_connect_type: Literal['add', 'concat', 'none'] = 'concat',
    ):
        super().__init__()
        assert depth % 2 == 0, "UNet-Transformer's depth should be even."

        self.time_embed = TimestepEmbedding(dim)
        if text_dim is None:
            text_dim = mel_dim
        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers = conv_layers)
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)

        # transformer layers & skip connections

        self.dim = dim
        self.skip_connect_type = skip_connect_type
        needs_skip_proj = skip_connect_type == 'concat'

        self.depth = depth
        self.layers = nn.ModuleList([])

        for idx in range(depth):
            is_later_half = idx >= (depth // 2)

            attn_norm = RMSNorm(dim)
            attn = Attention(
                processor = AttnProcessor(),
                dim = dim,
                heads = heads,
                dim_head = dim_head,
                dropout = dropout,
            )

            ff_norm = RMSNorm(dim)
            ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")

            skip_proj = nn.Linear(dim * 2, dim, bias = False) if needs_skip_proj and is_later_half else None

            self.layers.append(nn.ModuleList([
                skip_proj,
                attn_norm,
                attn,
                ff_norm,
                ff,
            ]))

        self.norm_out = RMSNorm(dim)
        self.proj_out = nn.Linear(dim, mel_dim)

    def forward(
        self,
        x: float['b n d'],  # nosied input audio
        cond: float['b n d'],  # masked cond audio
        text: int['b nt'],  # text
        time: float['b'] | float[''],  # time step
        drop_audio_cond,  # cfg for cond audio
        drop_text,  # cfg for text
        mask: bool['b n'] | None = None,
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = repeat(time, ' -> b', b = batch)

        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
        t = self.time_embed(time)
        text_embed = self.text_embed(text, seq_len, drop_text = drop_text)
        x = self.input_embed(x, cond, text_embed, drop_audio_cond = drop_audio_cond)

        # postfix time t to input x, [b n d] -> [b n+1 d]
        x, ps = pack((t, x), 'b * d')
        if mask is not None:
            mask = F.pad(mask, (1, 0), value=1)

        rope = self.rotary_embed.forward_from_seq_len(seq_len + 1)

        # flat unet transformer
        skip_connect_type = self.skip_connect_type
        skips = []
        for idx, (maybe_skip_proj, attn_norm, attn, ff_norm, ff) in enumerate(self.layers):
            layer = idx + 1

            # skip connection logic
            is_first_half = layer <= (self.depth // 2)
            is_later_half = not is_first_half

            if is_first_half:
                skips.append(x)

            if is_later_half:
                skip = skips.pop()
                if skip_connect_type == 'concat':
                    x = torch.cat((x, skip), dim = -1)
                    x = maybe_skip_proj(x)
                elif skip_connect_type == 'add':
                    x = x + skip

            # attention and feedforward blocks
            x = attn(attn_norm(x), rope = rope, mask = mask) + x
            x = ff(ff_norm(x)) + x

        assert len(skips) == 0

        _, x = unpack(self.norm_out(x), ps, 'b * d')

        return self.proj_out(x)
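The deleted UNetT is a flat transformer with U-Net style long skips: the first depth//2 layers push their activations onto a stack, the second half pops and merges them ('concat' goes through a learned projection, 'add' is a plain sum), which is why the constructor asserts an even depth. A small sketch with placeholder sizes, assuming the deleted package is importable:

import torch
from model.backbones.unett import UNetT

model = UNetT(dim=256, depth=6, heads=4, mel_dim=100, text_num_embeds=256, skip_connect_type='add')
x    = torch.randn(2, 300, 100)
cond = torch.randn(2, 300, 100)
text = torch.randint(0, 256, (2, 120))
out = model(x, cond, text, torch.rand(2), drop_audio_cond=False, drop_text=False)
print(out.shape)   # expected: torch.Size([2, 300, 100]); the packed time token is stripped before proj_out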
model/cfm.py
DELETED
@@ -1,279 +0,0 @@

"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Callable
from random import random

import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from torchdiffeq import odeint

from einops import rearrange

from model.modules import MelSpec

from model.utils import (
    default, exists,
    list_str_to_idx, list_str_to_tensor,
    lens_to_mask, mask_from_frac_lengths,
)


class CFM(nn.Module):
    def __init__(
        self,
        transformer: nn.Module,
        sigma = 0.,
        odeint_kwargs: dict = dict(
            # atol = 1e-5,
            # rtol = 1e-5,
            method = 'euler'  # 'midpoint'
        ),
        audio_drop_prob = 0.3,
        cond_drop_prob = 0.2,
        num_channels = None,
        mel_spec_module: nn.Module | None = None,
        mel_spec_kwargs: dict = dict(),
        frac_lengths_mask: tuple[float, float] = (0.7, 1.),
        vocab_char_map: dict[str: int] | None = None
    ):
        super().__init__()

        self.frac_lengths_mask = frac_lengths_mask

        # mel spec
        self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
        num_channels = default(num_channels, self.mel_spec.n_mel_channels)
        self.num_channels = num_channels

        # classifier-free guidance
        self.audio_drop_prob = audio_drop_prob
        self.cond_drop_prob = cond_drop_prob

        # transformer
        self.transformer = transformer
        dim = transformer.dim
        self.dim = dim

        # conditional flow related
        self.sigma = sigma

        # sampling related
        self.odeint_kwargs = odeint_kwargs

        # vocab map for tokenization
        self.vocab_char_map = vocab_char_map

    @property
    def device(self):
        return next(self.parameters()).device

    @torch.no_grad()
    def sample(
        self,
        cond: float['b n d'] | float['b nw'],
        text: int['b nt'] | list[str],
        duration: int | int['b'],
        *,
        lens: int['b'] | None = None,
        steps = 32,
        cfg_strength = 1.,
        sway_sampling_coef = None,
        seed: int | None = None,
        max_duration = 4096,
        vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
        no_ref_audio = False,
        duplicate_test = False,
        t_inter = 0.1,
        edit_mask = None,
    ):
        self.eval()

        # raw wave

        if cond.ndim == 2:
            cond = self.mel_spec(cond)
            cond = rearrange(cond, 'b d n -> b n d')
            assert cond.shape[-1] == self.num_channels

        batch, cond_seq_len, device = *cond.shape[:2], cond.device
        if not exists(lens):
            lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)

        # text

        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        if exists(text):
            text_lens = (text != -1).sum(dim = -1)
            lens = torch.maximum(text_lens, lens)  # make sure lengths are at least those of the text characters

        # duration

        cond_mask = lens_to_mask(lens)
        if edit_mask is not None:
            cond_mask = cond_mask & edit_mask

        if isinstance(duration, int):
            duration = torch.full((batch,), duration, device = device, dtype = torch.long)

        duration = torch.maximum(lens + 1, duration)  # just add one token so something is generated
        duration = duration.clamp(max = max_duration)
        max_duration = duration.amax()

        # duplicate test corner for inner time step oberservation
        if duplicate_test:
            test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)

        cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
        cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
        cond_mask = rearrange(cond_mask, '... -> ... 1')
        step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))  # allow direct control (cut cond audio) with lens passed in

        if batch > 1:
            mask = lens_to_mask(duration)
        else:  # save memory and speed up, as single inference need no mask currently
            mask = None

        # test for no ref audio
        if no_ref_audio:
            cond = torch.zeros_like(cond)

        # neural ode

        def fn(t, x):
            # at each step, conditioning is fixed
            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))

            # predict flow
            pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
            if cfg_strength < 1e-5:
                return pred

            null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
            return pred + (pred - null_pred) * cfg_strength

        # noise input
        # to make sure batch inference result is same with different batch size, and for sure single inference
        # still some difference maybe due to convolutional layers
        y0 = []
        for dur in duration:
            if exists(seed):
                torch.manual_seed(seed)
            y0.append(torch.randn(dur, self.num_channels, device = self.device))
        y0 = pad_sequence(y0, padding_value = 0, batch_first = True)

        t_start = 0

        # duplicate test corner for inner time step oberservation
        if duplicate_test:
            t_start = t_inter
            y0 = (1 - t_start) * y0 + t_start * test_cond
            steps = int(steps * (1 - t_start))

        t = torch.linspace(t_start, 1, steps, device = self.device)
        if sway_sampling_coef is not None:
            t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

        trajectory = odeint(fn, y0, t, **self.odeint_kwargs)

        sampled = trajectory[-1]
        out = sampled
        out = torch.where(cond_mask, cond, out)

        if exists(vocoder):
            out = rearrange(out, 'b n d -> b d n')
            out = vocoder(out)

        return out, trajectory

    def forward(
        self,
        inp: float['b n d'] | float['b nw'],  # mel or raw wave
        text: int['b nt'] | list[str],
        *,
        lens: int['b'] | None = None,
        noise_scheduler: str | None = None,
    ):
        # handle raw wave
        if inp.ndim == 2:
            inp = self.mel_spec(inp)
            inp = rearrange(inp, 'b d n -> b n d')
            assert inp.shape[-1] == self.num_channels

        batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma

        # handle text as string
        if isinstance(text, list):
            if exists(self.vocab_char_map):
                text = list_str_to_idx(text, self.vocab_char_map).to(device)
            else:
                text = list_str_to_tensor(text).to(device)
            assert text.shape[0] == batch

        # lens and mask
        if not exists(lens):
            lens = torch.full((batch,), seq_len, device = device)

        mask = lens_to_mask(lens, length = seq_len)  # useless here, as collate_fn will pad to max length in batch

        # get a random span to mask out for training conditionally
        frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)

        if exists(mask):
            rand_span_mask &= mask

        # mel is x1
        x1 = inp

        # x0 is gaussian noise
        x0 = torch.randn_like(x1)

        # time step
        time = torch.rand((batch,), dtype = dtype, device = self.device)
        # TODO. noise_scheduler

        # sample xt (φ_t(x) in the paper)
        t = rearrange(time, 'b -> b 1 1')
        φ = (1 - t) * x0 + t * x1
        flow = x1 - x0

        # only predict what is within the random mask span for infilling
        cond = torch.where(
            rand_span_mask[..., None],
            torch.zeros_like(x1), x1
        )

        # transformer and cfg training with a drop rate
        drop_audio_cond = random() < self.audio_drop_prob  # p_drop in voicebox paper
        if random() < self.cond_drop_prob:  # p_uncond in voicebox paper
            drop_audio_cond = True
            drop_text = True
        else:
            drop_text = False

        # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
        # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
        pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)

        # flow matching loss
        loss = F.mse_loss(pred, flow, reduction = 'none')
        loss = loss[rand_span_mask]

        return loss.mean(), cond, pred
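The deleted CFM wrapper carries both the training objective (forward returns the masked flow-matching MSE) and inference (sample integrates the learned flow with torchdiffeq's odeint using the Euler method, with optional classifier-free guidance and sway sampling). A hedged end-to-end sketch with placeholder sizes and random data; it assumes the deleted model package and its utils are importable:

import torch
from model.cfm import CFM
from model.backbones.dit import DiT

cfm = CFM(transformer=DiT(dim=256, depth=4, heads=4, mel_dim=100, text_num_embeds=256))

# training step: raw 24 kHz waveforms (b, nw) are converted to mels internally
wav = torch.randn(2, 24_000 * 3)                     # ~3 s of audio per sample
loss, cond, pred = cfm(wav, ["hello there", "general kenobi"])
loss.backward()

# inference: extend a reference mel to a requested duration (in mel frames)
ref_mel = torch.randn(1, 200, 100)
out, trajectory = cfm.sample(ref_mel, ["hello there. and some text to continue."],
                             duration=400, steps=32, cfg_strength=2.0, sway_sampling_coef=-1.0)
print(out.shape)                                     # expected: torch.Size([1, 400, 100])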
model/dataset.py
DELETED
@@ -1,257 +0,0 @@

import json
import random
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, Sampler
import torchaudio
from datasets import load_dataset, load_from_disk
from datasets import Dataset as Dataset_

from einops import rearrange

from model.modules import MelSpec


class HFDataset(Dataset):
    def __init__(
        self,
        hf_dataset: Dataset,
        target_sample_rate = 24_000,
        n_mel_channels = 100,
        hop_length = 256,
    ):
        self.data = hf_dataset
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length
        self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)

    def get_frame_len(self, index):
        row = self.data[index]
        audio = row['audio']['array']
        sample_rate = row['audio']['sampling_rate']
        return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data[index]
        audio = row['audio']['array']

        # logger.info(f"Audio shape: {audio.shape}")

        sample_rate = row['audio']['sampling_rate']
        duration = audio.shape[-1] / sample_rate

        if duration > 30 or duration < 0.3:
            return self.__getitem__((index + 1) % len(self.data))

        audio_tensor = torch.from_numpy(audio).float()

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            audio_tensor = resampler(audio_tensor)

        audio_tensor = rearrange(audio_tensor, 't -> 1 t')

        mel_spec = self.mel_spectrogram(audio_tensor)

        mel_spec = rearrange(mel_spec, '1 d t -> d t')

        text = row['text']

        return dict(
            mel_spec = mel_spec,
            text = text,
        )


class CustomDataset(Dataset):
    def __init__(
        self,
        custom_dataset: Dataset,
        durations = None,
        target_sample_rate = 24_000,
        hop_length = 256,
        n_mel_channels = 100,
        preprocessed_mel = False,
    ):
        self.data = custom_dataset
        self.durations = durations
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length
        self.preprocessed_mel = preprocessed_mel
        if not preprocessed_mel:
            self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)

    def get_frame_len(self, index):
        if self.durations is not None:  # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
            return self.durations[index] * self.target_sample_rate / self.hop_length
        return self.data[index]["duration"] * self.target_sample_rate / self.hop_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data[index]
        audio_path = row["audio_path"]
        text = row["text"]
        duration = row["duration"]

        if self.preprocessed_mel:
            mel_spec = torch.tensor(row["mel_spec"])

        else:
            audio, source_sample_rate = torchaudio.load(audio_path)

            if duration > 30 or duration < 0.3:
                return self.__getitem__((index + 1) % len(self.data))

            if source_sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
                audio = resampler(audio)

            mel_spec = self.mel_spectrogram(audio)
            mel_spec = rearrange(mel_spec, '1 d t -> d t')

        return dict(
            mel_spec = mel_spec,
            text = text,
        )


# Dynamic Batch Sampler

class DynamicBatchSampler(Sampler[list[int]]):
    """ Extension of Sampler that will do the following:
        1.  Change the batch size (essentially number of sequences)
            in a batch to ensure that the total number of frames are less
            than a certain threshold.
        2.  Make sure the padding efficiency in the batch is high.
    """

    def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
        self.sampler = sampler
        self.frames_threshold = frames_threshold
        self.max_samples = max_samples

        indices, batches = [], []
        data_source = self.sampler.data_source

        for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
            indices.append((idx, data_source.get_frame_len(idx)))
        indices.sort(key=lambda elem : elem[1])

        batch = []
        batch_frames = 0
        for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
            if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
                batch.append(idx)
                batch_frames += frame_len
            else:
                if len(batch) > 0:
                    batches.append(batch)
                if frame_len <= self.frames_threshold:
                    batch = [idx]
                    batch_frames = frame_len
                else:
                    batch = []
                    batch_frames = 0

        if not drop_last and len(batch) > 0:
            batches.append(batch)

        del indices

        # if want to have different batches between epochs, may just set a seed and log it in ckpt
        # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
        # e.g. for epoch n, use (random_seed + n)
        random.seed(random_seed)
        random.shuffle(batches)

        self.batches = batches

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)


# Load dataset

def load_dataset(
    dataset_name: str,
    tokenizer: str = "pinyin",
    dataset_type: str = "CustomDataset",
    audio_type: str = "raw",
    mel_spec_kwargs: dict = dict()
) -> CustomDataset | HFDataset:
    '''
    dataset_type    - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
                    - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
    '''

    print("Loading dataset ...")

    if dataset_type == "CustomDataset":
        if audio_type == "raw":
            try:
                train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
            except:
                train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
            preprocessed_mel = False
        elif audio_type == "mel":
            train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
            preprocessed_mel = True
        with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)

    elif dataset_type == "CustomDatasetPath":
        try:
            train_dataset = load_from_disk(f"{dataset_name}/raw")
        except:
            train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")

        with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)

    elif dataset_type == "HFDataset":
        print("Should manually modify the path of huggingface dataset to your need.\n" +
              "May also the corresponding script cuz different dataset may have different format.")
        pre, post = dataset_name.split("_")
        train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)

    return train_dataset


# collation

def collate_fn(batch):
    mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
    mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
    max_mel_length = mel_lengths.amax()

    padded_mel_specs = []
    for spec in mel_specs:  # TODO. maybe records mask for attention here
        padding = (0, max_mel_length - spec.size(-1))
        padded_spec = F.pad(spec, padding, value = 0)
        padded_mel_specs.append(padded_spec)

    mel_specs = torch.stack(padded_mel_specs)

    text = [item['text'] for item in batch]
    text_lengths = torch.LongTensor([len(item) for item in text])

    return dict(
        mel = mel_specs,
        mel_lengths = mel_lengths,
        text = text,
        text_lengths = text_lengths,
    )
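The pieces above were wired together roughly as follows: load_dataset builds a CustomDataset (or HFDataset), DynamicBatchSampler packs indices so each batch stays under a per-GPU frame budget, and collate_fn pads the mels. A sketch assuming the deleted module is importable; the dataset name, frame threshold, and seed below are placeholders, and it expects a prepared dataset under data/:

from torch.utils.data import DataLoader, SequentialSampler
from model.dataset import load_dataset, DynamicBatchSampler, collate_fn

train_dataset = load_dataset("Emilia_ZH_EN", tokenizer="pinyin")   # placeholder name; expects data/Emilia_ZH_EN_pinyin/
sampler = SequentialSampler(train_dataset)
batch_sampler = DynamicBatchSampler(sampler, frames_threshold=38_400, max_samples=64, random_seed=666)

loader = DataLoader(train_dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=4)
batch = next(iter(loader))
print(batch["mel"].shape, batch["text_lengths"])   # padded mels (b, d, n_max) and per-sample text lengths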
model/ecapa_tdnn.py
DELETED
@@ -1,268 +0,0 @@

# just for speaker similarity evaluation, third-party code

# From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
# part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN

import os
import torch
import torch.nn as nn
import torch.nn.functional as F


''' Res2Conv1d + BatchNorm1d + ReLU
'''

class Res2Conv1dReluBn(nn.Module):
    '''
    in_channels == out_channels == channels
    '''

    def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
        super().__init__()
        assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
        self.scale = scale
        self.width = channels // scale
        self.nums = scale if scale == 1 else scale - 1

        self.convs = []
        self.bns = []
        for i in range(self.nums):
            self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
            self.bns.append(nn.BatchNorm1d(self.width))
        self.convs = nn.ModuleList(self.convs)
        self.bns = nn.ModuleList(self.bns)

    def forward(self, x):
        out = []
        spx = torch.split(x, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            # Order: conv -> relu -> bn
            sp = self.convs[i](sp)
            sp = self.bns[i](F.relu(sp))
            out.append(sp)
        if self.scale != 1:
            out.append(spx[self.nums])
        out = torch.cat(out, dim=1)

        return out


''' Conv1d + BatchNorm1d + ReLU
'''

class Conv1dReluBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x)))


''' The SE connection of 1D case.
'''

class SE_Connect(nn.Module):
    def __init__(self, channels, se_bottleneck_dim=128):
        super().__init__()
        self.linear1 = nn.Linear(channels, se_bottleneck_dim)
        self.linear2 = nn.Linear(se_bottleneck_dim, channels)

    def forward(self, x):
        out = x.mean(dim=2)
        out = F.relu(self.linear1(out))
        out = torch.sigmoid(self.linear2(out))
        out = x * out.unsqueeze(2)

        return out


''' SE-Res2Block of the ECAPA-TDNN architecture.
'''

# def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
#     return nn.Sequential(
#         Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
#         Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
#         Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
#         SE_Connect(channels)
#     )

class SE_Res2Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
        super().__init__()
        self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
        self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)

        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
            )

    def forward(self, x):
        residual = x
        if self.shortcut:
            residual = self.shortcut(x)

        x = self.Conv1dReluBn1(x)
        x = self.Res2Conv1dReluBn(x)
        x = self.Conv1dReluBn2(x)
        x = self.SE_Connect(x)

        return x + residual


''' Attentive weighted mean and standard deviation pooling.
'''

class AttentiveStatsPool(nn.Module):
    def __init__(self, in_dim, attention_channels=128, global_context_att=False):
        super().__init__()
        self.global_context_att = global_context_att

        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        if global_context_att:
            self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1)  # equals W and b in the paper
        else:
            self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):

        if self.global_context_att:
            context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
            context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
            x_in = torch.cat((x, context_mean, context_std), dim=1)
        else:
            x_in = x

        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x_in))
        # alpha = F.relu(self.linear1(x_in))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
        std = torch.sqrt(residuals.clamp(min=1e-9))
        return torch.cat([mean, std], dim=1)


class ECAPA_TDNN(nn.Module):
    def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
                 feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
        super().__init__()

        self.feat_type = feat_type
        self.feature_selection = feature_selection
        self.update_extract = update_extract
        self.sr = sr

        torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
        try:
            local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
            self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
        except:
            self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)

        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
            self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
        if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
            self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False

        self.feat_num = self.get_feat_num()
        self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))

        if feat_type != 'fbank' and feat_type != 'mfcc':
            freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
            for name, param in self.feature_extract.named_parameters():
                for freeze_val in freeze_list:
                    if freeze_val in name:
                        param.requires_grad = False
                        break

        if not self.update_extract:
            for param in self.feature_extract.parameters():
                param.requires_grad = False

        self.instance_norm = nn.InstanceNorm1d(feat_dim)
        # self.channels = [channels] * 4 + [channels * 3]
        self.channels = [channels] * 4 + [1536]

        self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
        self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
        self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
        self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)

        # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
        cat_channels = channels * 3
        self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
        self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
        self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
        self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)


    def get_feat_num(self):
        self.feature_extract.eval()
        wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
        with torch.no_grad():
            features = self.feature_extract(wav)
        select_feature = features[self.feature_selection]
        if isinstance(select_feature, (list, tuple)):
            return len(select_feature)
        else:
            return 1

    def get_feat(self, x):
        if self.update_extract:
            x = self.feature_extract([sample for sample in x])
        else:
            with torch.no_grad():
                if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
                    x = self.feature_extract(x) + 1e-6  # B x feat_dim x time_len
                else:
                    x = self.feature_extract([sample for sample in x])

        if self.feat_type == 'fbank':
            x = x.log()

        if self.feat_type != "fbank" and self.feat_type != "mfcc":
            x = x[self.feature_selection]
            if isinstance(x, (list, tuple)):
                x = torch.stack(x, dim=0)
            else:
                x = x.unsqueeze(0)
            norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            x = (norm_weights * x).sum(dim=0)
            x = torch.transpose(x, 1, 2) + 1e-6

        x = self.instance_norm(x)
        return x

    def forward(self, x):
        x = self.get_feat(x)

        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        out4 = self.layer4(out3)

        out = torch.cat([out2, out3, out4], dim=1)
        out = F.relu(self.conv(out))
        out = self.bn(self.pooling(out))
        out = self.linear(out)

        return out


def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
    return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
                      feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
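As its header notes, this third-party ECAPA-TDNN was only used to score speaker similarity: embed two waveforms and compare them with cosine similarity. A sketch of that usage (the wavlm_large upstream is fetched via torch.hub/s3prl on first use; fine-tuned checkpoint loading is omitted here, and feat_dim=1024 matches WavLM-Large's hidden size):

import torch
import torch.nn.functional as F
from model.ecapa_tdnn import ECAPA_TDNN_SMALL

model = ECAPA_TDNN_SMALL(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
model.eval()

wav1 = torch.randn(1, 16_000 * 4)   # 4 s at 16 kHz
wav2 = torch.randn(1, 16_000 * 4)
with torch.no_grad():
    emb1, emb2 = model(wav1), model(wav2)
similarity = F.cosine_similarity(emb1, emb2)
print(similarity.item())            # closer to 1.0 means more similar speakers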
model/modules.py
DELETED
@@ -1,574 +0,0 @@

"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""

from __future__ import annotations
from typing import Optional
import math

import torch
from torch import nn
import torch.nn.functional as F
import torchaudio

from einops import rearrange
from x_transformers.x_transformers import apply_rotary_pos_emb


# raw wav to mel spec

class MelSpec(nn.Module):
    def __init__(
        self,
        filter_length = 1024,
        hop_length = 256,
        win_length = 1024,
        n_mel_channels = 100,
        target_sample_rate = 24_000,
        normalize = False,
        power = 1,
        norm = None,
        center = True,
    ):
        super().__init__()
        self.n_mel_channels = n_mel_channels

        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate = target_sample_rate,
            n_fft = filter_length,
            win_length = win_length,
            hop_length = hop_length,
            n_mels = n_mel_channels,
            power = power,
            center = center,
            normalized = normalize,
            norm = norm,
        )

        self.register_buffer('dummy', torch.tensor(0), persistent = False)

    def forward(self, inp):
        if len(inp.shape) == 3:
            inp = rearrange(inp, 'b 1 nw -> b nw')

        assert len(inp.shape) == 2

        if self.dummy.device != inp.device:
            self.to(inp.device)

        mel = self.mel_stft(inp)
        mel = mel.clamp(min = 1e-5).log()
        return mel


# sinusoidal position embedding

class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb


# convolutional position embedding

class ConvPositionEmbedding(nn.Module):
    def __init__(self, dim, kernel_size = 31, groups = 16):
        super().__init__()
        assert kernel_size % 2 != 0
        self.conv1d = nn.Sequential(
            nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
            nn.Mish(),
            nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
            nn.Mish(),
        )

    def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
        if mask is not None:
            mask = mask[..., None]
            x = x.masked_fill(~mask, 0.)

        x = rearrange(x, 'b n d -> b d n')
        x = self.conv1d(x)
        out = rearrange(x, 'b d n -> b n d')

        if mask is not None:
            out = out.masked_fill(~mask, 0.)

        return out


# rotary positional embedding related

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
    theta *= theta_rescale_factor ** (dim / (dim - 2))
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return torch.cat([freqs_cos, freqs_sin], dim=-1)

def get_pos_embed_indices(start, length, max_pos, scale=1.):
    # length = length if isinstance(length, int) else length.max()
    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
    pos = start.unsqueeze(1) + (
        torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
        scale.unsqueeze(1)).long()
    # avoid extra long error.
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos
|
138 |
-
|
139 |
-
|
140 |
-
# Global Response Normalization layer (Instance Normalization ?)
|
141 |
-
|
142 |
-
class GRN(nn.Module):
|
143 |
-
def __init__(self, dim):
|
144 |
-
super().__init__()
|
145 |
-
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
146 |
-
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
147 |
-
|
148 |
-
def forward(self, x):
|
149 |
-
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
150 |
-
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
151 |
-
return self.gamma * (x * Nx) + self.beta + x
|
152 |
-
|
153 |
-
|
154 |
-
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
155 |
-
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
156 |
-
|
157 |
-
class ConvNeXtV2Block(nn.Module):
|
158 |
-
def __init__(
|
159 |
-
self,
|
160 |
-
dim: int,
|
161 |
-
intermediate_dim: int,
|
162 |
-
dilation: int = 1,
|
163 |
-
):
|
164 |
-
super().__init__()
|
165 |
-
padding = (dilation * (7 - 1)) // 2
|
166 |
-
self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
|
167 |
-
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
168 |
-
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
169 |
-
self.act = nn.GELU()
|
170 |
-
self.grn = GRN(intermediate_dim)
|
171 |
-
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
172 |
-
|
173 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
174 |
-
residual = x
|
175 |
-
x = x.transpose(1, 2) # b n d -> b d n
|
176 |
-
x = self.dwconv(x)
|
177 |
-
x = x.transpose(1, 2) # b d n -> b n d
|
178 |
-
x = self.norm(x)
|
179 |
-
x = self.pwconv1(x)
|
180 |
-
x = self.act(x)
|
181 |
-
x = self.grn(x)
|
182 |
-
x = self.pwconv2(x)
|
183 |
-
return residual + x
|
184 |
-
|
185 |
-
|
186 |
-
# AdaLayerNormZero
|
187 |
-
# return with modulated x for attn input, and params for later mlp modulation
|
188 |
-
|
189 |
-
class AdaLayerNormZero(nn.Module):
|
190 |
-
def __init__(self, dim):
|
191 |
-
super().__init__()
|
192 |
-
|
193 |
-
self.silu = nn.SiLU()
|
194 |
-
self.linear = nn.Linear(dim, dim * 6)
|
195 |
-
|
196 |
-
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
197 |
-
|
198 |
-
def forward(self, x, emb = None):
|
199 |
-
emb = self.linear(self.silu(emb))
|
200 |
-
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
201 |
-
|
202 |
-
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
203 |
-
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
204 |
-
|
205 |
-
|
206 |
-
# AdaLayerNormZero for final layer
|
207 |
-
# return only with modulated x for attn input, cuz no more mlp modulation
|
208 |
-
|
209 |
-
class AdaLayerNormZero_Final(nn.Module):
|
210 |
-
def __init__(self, dim):
|
211 |
-
super().__init__()
|
212 |
-
|
213 |
-
self.silu = nn.SiLU()
|
214 |
-
self.linear = nn.Linear(dim, dim * 2)
|
215 |
-
|
216 |
-
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
217 |
-
|
218 |
-
def forward(self, x, emb):
|
219 |
-
emb = self.linear(self.silu(emb))
|
220 |
-
scale, shift = torch.chunk(emb, 2, dim=1)
|
221 |
-
|
222 |
-
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
223 |
-
return x
|
224 |
-
|
225 |
-
|
226 |
-
# FeedForward
|
227 |
-
|
228 |
-
class FeedForward(nn.Module):
|
229 |
-
def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
|
230 |
-
super().__init__()
|
231 |
-
inner_dim = int(dim * mult)
|
232 |
-
dim_out = dim_out if dim_out is not None else dim
|
233 |
-
|
234 |
-
activation = nn.GELU(approximate=approximate)
|
235 |
-
project_in = nn.Sequential(
|
236 |
-
nn.Linear(dim, inner_dim),
|
237 |
-
activation
|
238 |
-
)
|
239 |
-
self.ff = nn.Sequential(
|
240 |
-
project_in,
|
241 |
-
nn.Dropout(dropout),
|
242 |
-
nn.Linear(inner_dim, dim_out)
|
243 |
-
)
|
244 |
-
|
245 |
-
def forward(self, x):
|
246 |
-
return self.ff(x)
|
247 |
-
|
248 |
-
|
249 |
-
# Attention with possible joint part
|
250 |
-
# modified from diffusers/src/diffusers/models/attention_processor.py
|
251 |
-
|
252 |
-
class Attention(nn.Module):
|
253 |
-
def __init__(
|
254 |
-
self,
|
255 |
-
processor: JointAttnProcessor | AttnProcessor,
|
256 |
-
dim: int,
|
257 |
-
heads: int = 8,
|
258 |
-
dim_head: int = 64,
|
259 |
-
dropout: float = 0.0,
|
260 |
-
context_dim: Optional[int] = None, # if not None -> joint attention
|
261 |
-
context_pre_only = None,
|
262 |
-
):
|
263 |
-
super().__init__()
|
264 |
-
|
265 |
-
if not hasattr(F, "scaled_dot_product_attention"):
|
266 |
-
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
267 |
-
|
268 |
-
self.processor = processor
|
269 |
-
|
270 |
-
self.dim = dim
|
271 |
-
self.heads = heads
|
272 |
-
self.inner_dim = dim_head * heads
|
273 |
-
self.dropout = dropout
|
274 |
-
|
275 |
-
self.context_dim = context_dim
|
276 |
-
self.context_pre_only = context_pre_only
|
277 |
-
|
278 |
-
self.to_q = nn.Linear(dim, self.inner_dim)
|
279 |
-
self.to_k = nn.Linear(dim, self.inner_dim)
|
280 |
-
self.to_v = nn.Linear(dim, self.inner_dim)
|
281 |
-
|
282 |
-
if self.context_dim is not None:
|
283 |
-
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
284 |
-
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
285 |
-
if self.context_pre_only is not None:
|
286 |
-
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
287 |
-
|
288 |
-
self.to_out = nn.ModuleList([])
|
289 |
-
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
290 |
-
self.to_out.append(nn.Dropout(dropout))
|
291 |
-
|
292 |
-
if self.context_pre_only is not None and not self.context_pre_only:
|
293 |
-
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
294 |
-
|
295 |
-
def forward(
|
296 |
-
self,
|
297 |
-
x: float['b n d'], # noised input x
|
298 |
-
c: float['b n d'] = None, # context c
|
299 |
-
mask: bool['b n'] | None = None,
|
300 |
-
rope = None, # rotary position embedding for x
|
301 |
-
c_rope = None, # rotary position embedding for c
|
302 |
-
) -> torch.Tensor:
|
303 |
-
if c is not None:
|
304 |
-
return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
|
305 |
-
else:
|
306 |
-
return self.processor(self, x, mask = mask, rope = rope)
|
307 |
-
|
308 |
-
|
309 |
-
# Attention processor
|
310 |
-
|
311 |
-
class AttnProcessor:
|
312 |
-
def __init__(self):
|
313 |
-
pass
|
314 |
-
|
315 |
-
def __call__(
|
316 |
-
self,
|
317 |
-
attn: Attention,
|
318 |
-
x: float['b n d'], # noised input x
|
319 |
-
mask: bool['b n'] | None = None,
|
320 |
-
rope = None, # rotary position embedding
|
321 |
-
) -> torch.FloatTensor:
|
322 |
-
|
323 |
-
batch_size = x.shape[0]
|
324 |
-
|
325 |
-
# `sample` projections.
|
326 |
-
query = attn.to_q(x)
|
327 |
-
key = attn.to_k(x)
|
328 |
-
value = attn.to_v(x)
|
329 |
-
|
330 |
-
# apply rotary position embedding
|
331 |
-
if rope is not None:
|
332 |
-
freqs, xpos_scale = rope
|
333 |
-
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
|
334 |
-
|
335 |
-
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
336 |
-
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
337 |
-
|
338 |
-
# attention
|
339 |
-
inner_dim = key.shape[-1]
|
340 |
-
head_dim = inner_dim // attn.heads
|
341 |
-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
342 |
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
343 |
-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
344 |
-
|
345 |
-
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
346 |
-
if mask is not None:
|
347 |
-
attn_mask = mask
|
348 |
-
attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
|
349 |
-
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
350 |
-
else:
|
351 |
-
attn_mask = None
|
352 |
-
|
353 |
-
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
354 |
-
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
355 |
-
x = x.to(query.dtype)
|
356 |
-
|
357 |
-
# linear proj
|
358 |
-
x = attn.to_out[0](x)
|
359 |
-
# dropout
|
360 |
-
x = attn.to_out[1](x)
|
361 |
-
|
362 |
-
if mask is not None:
|
363 |
-
mask = rearrange(mask, 'b n -> b n 1')
|
364 |
-
x = x.masked_fill(~mask, 0.)
|
365 |
-
|
366 |
-
return x
|
367 |
-
|
368 |
-
|
369 |
-
# Joint Attention processor for MM-DiT
|
370 |
-
# modified from diffusers/src/diffusers/models/attention_processor.py
|
371 |
-
|
372 |
-
class JointAttnProcessor:
|
373 |
-
def __init__(self):
|
374 |
-
pass
|
375 |
-
|
376 |
-
def __call__(
|
377 |
-
self,
|
378 |
-
attn: Attention,
|
379 |
-
x: float['b n d'], # noised input x
|
380 |
-
c: float['b nt d'] = None, # context c, here text
|
381 |
-
mask: bool['b n'] | None = None,
|
382 |
-
rope = None, # rotary position embedding for x
|
383 |
-
c_rope = None, # rotary position embedding for c
|
384 |
-
) -> torch.FloatTensor:
|
385 |
-
residual = x
|
386 |
-
|
387 |
-
batch_size = c.shape[0]
|
388 |
-
|
389 |
-
# `sample` projections.
|
390 |
-
query = attn.to_q(x)
|
391 |
-
key = attn.to_k(x)
|
392 |
-
value = attn.to_v(x)
|
393 |
-
|
394 |
-
# `context` projections.
|
395 |
-
c_query = attn.to_q_c(c)
|
396 |
-
c_key = attn.to_k_c(c)
|
397 |
-
c_value = attn.to_v_c(c)
|
398 |
-
|
399 |
-
# apply rope for context and noised input independently
|
400 |
-
if rope is not None:
|
401 |
-
freqs, xpos_scale = rope
|
402 |
-
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
|
403 |
-
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
404 |
-
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
405 |
-
if c_rope is not None:
|
406 |
-
freqs, xpos_scale = c_rope
|
407 |
-
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
|
408 |
-
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
409 |
-
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
410 |
-
|
411 |
-
# attention
|
412 |
-
query = torch.cat([query, c_query], dim=1)
|
413 |
-
key = torch.cat([key, c_key], dim=1)
|
414 |
-
value = torch.cat([value, c_value], dim=1)
|
415 |
-
|
416 |
-
inner_dim = key.shape[-1]
|
417 |
-
head_dim = inner_dim // attn.heads
|
418 |
-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
419 |
-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
420 |
-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
421 |
-
|
422 |
-
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
423 |
-
if mask is not None:
|
424 |
-
attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
|
425 |
-
attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
|
426 |
-
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
427 |
-
else:
|
428 |
-
attn_mask = None
|
429 |
-
|
430 |
-
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
431 |
-
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
432 |
-
x = x.to(query.dtype)
|
433 |
-
|
434 |
-
# Split the attention outputs.
|
435 |
-
x, c = (
|
436 |
-
x[:, :residual.shape[1]],
|
437 |
-
x[:, residual.shape[1]:],
|
438 |
-
)
|
439 |
-
|
440 |
-
# linear proj
|
441 |
-
x = attn.to_out[0](x)
|
442 |
-
# dropout
|
443 |
-
x = attn.to_out[1](x)
|
444 |
-
if not attn.context_pre_only:
|
445 |
-
c = attn.to_out_c(c)
|
446 |
-
|
447 |
-
if mask is not None:
|
448 |
-
mask = rearrange(mask, 'b n -> b n 1')
|
449 |
-
x = x.masked_fill(~mask, 0.)
|
450 |
-
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
451 |
-
|
452 |
-
return x, c
|
453 |
-
|
454 |
-
|
455 |
-
# DiT Block
|
456 |
-
|
457 |
-
class DiTBlock(nn.Module):
|
458 |
-
|
459 |
-
def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
|
460 |
-
super().__init__()
|
461 |
-
|
462 |
-
self.attn_norm = AdaLayerNormZero(dim)
|
463 |
-
self.attn = Attention(
|
464 |
-
processor = AttnProcessor(),
|
465 |
-
dim = dim,
|
466 |
-
heads = heads,
|
467 |
-
dim_head = dim_head,
|
468 |
-
dropout = dropout,
|
469 |
-
)
|
470 |
-
|
471 |
-
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
472 |
-
self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
|
473 |
-
|
474 |
-
def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
|
475 |
-
# pre-norm & modulation for attention input
|
476 |
-
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
477 |
-
|
478 |
-
# attention
|
479 |
-
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
480 |
-
|
481 |
-
# process attention output for input x
|
482 |
-
x = x + gate_msa.unsqueeze(1) * attn_output
|
483 |
-
|
484 |
-
norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
485 |
-
ff_output = self.ff(norm)
|
486 |
-
x = x + gate_mlp.unsqueeze(1) * ff_output
|
487 |
-
|
488 |
-
return x
|
489 |
-
|
490 |
-
|
491 |
-
# MMDiT Block https://arxiv.org/abs/2403.03206
|
492 |
-
|
493 |
-
class MMDiTBlock(nn.Module):
|
494 |
-
r"""
|
495 |
-
modified from diffusers/src/diffusers/models/attention.py
|
496 |
-
notes.
|
497 |
-
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
|
498 |
-
_x: noised input related. (right part)
|
499 |
-
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
500 |
-
"""
|
501 |
-
|
502 |
-
def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
|
503 |
-
super().__init__()
|
504 |
-
|
505 |
-
self.context_pre_only = context_pre_only
|
506 |
-
|
507 |
-
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
508 |
-
self.attn_norm_x = AdaLayerNormZero(dim)
|
509 |
-
self.attn = Attention(
|
510 |
-
processor = JointAttnProcessor(),
|
511 |
-
dim = dim,
|
512 |
-
heads = heads,
|
513 |
-
dim_head = dim_head,
|
514 |
-
dropout = dropout,
|
515 |
-
context_dim = dim,
|
516 |
-
context_pre_only = context_pre_only,
|
517 |
-
)
|
518 |
-
|
519 |
-
if not context_pre_only:
|
520 |
-
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
521 |
-
self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
|
522 |
-
else:
|
523 |
-
self.ff_norm_c = None
|
524 |
-
self.ff_c = None
|
525 |
-
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
526 |
-
self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
|
527 |
-
|
528 |
-
def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
|
529 |
-
# pre-norm & modulation for attention input
|
530 |
-
if self.context_pre_only:
|
531 |
-
norm_c = self.attn_norm_c(c, t)
|
532 |
-
else:
|
533 |
-
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
|
534 |
-
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
|
535 |
-
|
536 |
-
# attention
|
537 |
-
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
|
538 |
-
|
539 |
-
# process attention output for context c
|
540 |
-
if self.context_pre_only:
|
541 |
-
c = None
|
542 |
-
else: # if not last layer
|
543 |
-
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
544 |
-
|
545 |
-
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
546 |
-
c_ff_output = self.ff_c(norm_c)
|
547 |
-
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
548 |
-
|
549 |
-
# process attention output for input x
|
550 |
-
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
551 |
-
|
552 |
-
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
553 |
-
x_ff_output = self.ff_x(norm_x)
|
554 |
-
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
555 |
-
|
556 |
-
return c, x
|
557 |
-
|
558 |
-
|
559 |
-
# time step conditioning embedding
|
560 |
-
|
561 |
-
class TimestepEmbedding(nn.Module):
|
562 |
-
def __init__(self, dim, freq_embed_dim=256):
|
563 |
-
super().__init__()
|
564 |
-
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
565 |
-
self.time_mlp = nn.Sequential(
|
566 |
-
nn.Linear(freq_embed_dim, dim),
|
567 |
-
nn.SiLU(),
|
568 |
-
nn.Linear(dim, dim)
|
569 |
-
)
|
570 |
-
|
571 |
-
def forward(self, timestep: float['b']):
|
572 |
-
time_hidden = self.time_embed(timestep)
|
573 |
-
time = self.time_mlp(time_hidden) # b d
|
574 |
-
return time
|
|
|
|
|
|
|
|
|
|
|
|
|
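Note: the timestep conditioning deleted above (TimestepEmbedding) starts from the classic sinusoidal embedding in SinusPositionEmbedding. A standalone sketch of that computation, functionally the same as the deleted class but without the nn.Module wrapper:

```python
import math
import torch

def sinus_position_embedding(t: torch.Tensor, dim: int, scale: float = 1000.0) -> torch.Tensor:
    # Same math as the deleted SinusPositionEmbedding.forward: for a batch of
    # scalar timesteps t with shape (b,), return [sin | cos] features of size dim.
    half_dim = dim // 2
    exponent = -math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32) * exponent)
    args = scale * t.float().unsqueeze(1) * freqs.unsqueeze(0)  # (b, half_dim)
    return torch.cat((args.sin(), args.cos()), dim=-1)          # (b, dim)

emb = sinus_position_embedding(torch.tensor([0.0, 0.5, 1.0]), dim=256)
print(emb.shape)  # torch.Size([3, 256])
```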
|
|
model/trainer.py
DELETED
@@ -1,250 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
import os
|
4 |
-
import gc
|
5 |
-
from tqdm import tqdm
|
6 |
-
import wandb
|
7 |
-
|
8 |
-
import torch
|
9 |
-
from torch.optim import AdamW
|
10 |
-
from torch.utils.data import DataLoader, Dataset, SequentialSampler
|
11 |
-
from torch.optim.lr_scheduler import LinearLR, SequentialLR
|
12 |
-
|
13 |
-
from einops import rearrange
|
14 |
-
|
15 |
-
from accelerate import Accelerator
|
16 |
-
from accelerate.utils import DistributedDataParallelKwargs
|
17 |
-
|
18 |
-
from ema_pytorch import EMA
|
19 |
-
|
20 |
-
from model import CFM
|
21 |
-
from model.utils import exists, default
|
22 |
-
from model.dataset import DynamicBatchSampler, collate_fn
|
23 |
-
|
24 |
-
|
25 |
-
# trainer
|
26 |
-
|
27 |
-
class Trainer:
|
28 |
-
def __init__(
|
29 |
-
self,
|
30 |
-
model: CFM,
|
31 |
-
epochs,
|
32 |
-
learning_rate,
|
33 |
-
num_warmup_updates = 20000,
|
34 |
-
save_per_updates = 1000,
|
35 |
-
checkpoint_path = None,
|
36 |
-
batch_size = 32,
|
37 |
-
batch_size_type: str = "sample",
|
38 |
-
max_samples = 32,
|
39 |
-
grad_accumulation_steps = 1,
|
40 |
-
max_grad_norm = 1.0,
|
41 |
-
noise_scheduler: str | None = None,
|
42 |
-
duration_predictor: torch.nn.Module | None = None,
|
43 |
-
wandb_project = "test_e2-tts",
|
44 |
-
wandb_run_name = "test_run",
|
45 |
-
wandb_resume_id: str = None,
|
46 |
-
last_per_steps = None,
|
47 |
-
accelerate_kwargs: dict = dict(),
|
48 |
-
ema_kwargs: dict = dict()
|
49 |
-
):
|
50 |
-
|
51 |
-
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
|
52 |
-
|
53 |
-
self.accelerator = Accelerator(
|
54 |
-
log_with = "wandb",
|
55 |
-
kwargs_handlers = [ddp_kwargs],
|
56 |
-
gradient_accumulation_steps = grad_accumulation_steps,
|
57 |
-
**accelerate_kwargs
|
58 |
-
)
|
59 |
-
|
60 |
-
if exists(wandb_resume_id):
|
61 |
-
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
|
62 |
-
else:
|
63 |
-
init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
|
64 |
-
self.accelerator.init_trackers(
|
65 |
-
project_name = wandb_project,
|
66 |
-
init_kwargs=init_kwargs,
|
67 |
-
config={"epochs": epochs,
|
68 |
-
"learning_rate": learning_rate,
|
69 |
-
"num_warmup_updates": num_warmup_updates,
|
70 |
-
"batch_size": batch_size,
|
71 |
-
"batch_size_type": batch_size_type,
|
72 |
-
"max_samples": max_samples,
|
73 |
-
"grad_accumulation_steps": grad_accumulation_steps,
|
74 |
-
"max_grad_norm": max_grad_norm,
|
75 |
-
"gpus": self.accelerator.num_processes,
|
76 |
-
"noise_scheduler": noise_scheduler}
|
77 |
-
)
|
78 |
-
|
79 |
-
self.model = model
|
80 |
-
|
81 |
-
if self.is_main:
|
82 |
-
self.ema_model = EMA(
|
83 |
-
model,
|
84 |
-
include_online_model = False,
|
85 |
-
**ema_kwargs
|
86 |
-
)
|
87 |
-
|
88 |
-
self.ema_model.to(self.accelerator.device)
|
89 |
-
|
90 |
-
self.epochs = epochs
|
91 |
-
self.num_warmup_updates = num_warmup_updates
|
92 |
-
self.save_per_updates = save_per_updates
|
93 |
-
self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
|
94 |
-
self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
|
95 |
-
|
96 |
-
self.batch_size = batch_size
|
97 |
-
self.batch_size_type = batch_size_type
|
98 |
-
self.max_samples = max_samples
|
99 |
-
self.grad_accumulation_steps = grad_accumulation_steps
|
100 |
-
self.max_grad_norm = max_grad_norm
|
101 |
-
|
102 |
-
self.noise_scheduler = noise_scheduler
|
103 |
-
|
104 |
-
self.duration_predictor = duration_predictor
|
105 |
-
|
106 |
-
self.optimizer = AdamW(model.parameters(), lr=learning_rate)
|
107 |
-
self.model, self.optimizer = self.accelerator.prepare(
|
108 |
-
self.model, self.optimizer
|
109 |
-
)
|
110 |
-
|
111 |
-
@property
|
112 |
-
def is_main(self):
|
113 |
-
return self.accelerator.is_main_process
|
114 |
-
|
115 |
-
def save_checkpoint(self, step, last=False):
|
116 |
-
self.accelerator.wait_for_everyone()
|
117 |
-
if self.is_main:
|
118 |
-
checkpoint = dict(
|
119 |
-
model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
|
120 |
-
optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
|
121 |
-
ema_model_state_dict = self.ema_model.state_dict(),
|
122 |
-
scheduler_state_dict = self.scheduler.state_dict(),
|
123 |
-
step = step
|
124 |
-
)
|
125 |
-
if not os.path.exists(self.checkpoint_path):
|
126 |
-
os.makedirs(self.checkpoint_path)
|
127 |
-
if last == True:
|
128 |
-
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
|
129 |
-
print(f"Saved last checkpoint at step {step}")
|
130 |
-
else:
|
131 |
-
self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
|
132 |
-
|
133 |
-
def load_checkpoint(self):
|
134 |
-
if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
|
135 |
-
return 0
|
136 |
-
|
137 |
-
self.accelerator.wait_for_everyone()
|
138 |
-
if "model_last.pt" in os.listdir(self.checkpoint_path):
|
139 |
-
latest_checkpoint = "model_last.pt"
|
140 |
-
else:
|
141 |
-
latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
|
142 |
-
# checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
|
143 |
-
checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
|
144 |
-
|
145 |
-
if self.is_main:
|
146 |
-
self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
|
147 |
-
|
148 |
-
if 'step' in checkpoint:
|
149 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
|
150 |
-
self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
|
151 |
-
if self.scheduler:
|
152 |
-
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
153 |
-
step = checkpoint['step']
|
154 |
-
else:
|
155 |
-
checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
|
156 |
-
self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
|
157 |
-
step = 0
|
158 |
-
|
159 |
-
del checkpoint; gc.collect()
|
160 |
-
return step
|
161 |
-
|
162 |
-
def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
|
163 |
-
|
164 |
-
if exists(resumable_with_seed):
|
165 |
-
generator = torch.Generator()
|
166 |
-
generator.manual_seed(resumable_with_seed)
|
167 |
-
else:
|
168 |
-
generator = None
|
169 |
-
|
170 |
-
if self.batch_size_type == "sample":
|
171 |
-
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
|
172 |
-
batch_size=self.batch_size, shuffle=True, generator=generator)
|
173 |
-
elif self.batch_size_type == "frame":
|
174 |
-
self.accelerator.even_batches = False
|
175 |
-
sampler = SequentialSampler(train_dataset)
|
176 |
-
batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
|
177 |
-
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
|
178 |
-
batch_sampler=batch_sampler)
|
179 |
-
else:
|
180 |
-
raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
|
181 |
-
|
182 |
-
# accelerator.prepare() dispatches batches to devices;
|
183 |
-
# which means the length of dataloader calculated before, should consider the number of devices
|
184 |
-
warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
|
185 |
-
# otherwise by default with split_batches=False, warmup steps change with num_processes
|
186 |
-
total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
|
187 |
-
decay_steps = total_steps - warmup_steps
|
188 |
-
warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
|
189 |
-
decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
|
190 |
-
self.scheduler = SequentialLR(self.optimizer,
|
191 |
-
schedulers=[warmup_scheduler, decay_scheduler],
|
192 |
-
milestones=[warmup_steps])
|
193 |
-
train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
|
194 |
-
start_step = self.load_checkpoint()
|
195 |
-
global_step = start_step
|
196 |
-
|
197 |
-
if exists(resumable_with_seed):
|
198 |
-
orig_epoch_step = len(train_dataloader)
|
199 |
-
skipped_epoch = int(start_step // orig_epoch_step)
|
200 |
-
skipped_batch = start_step % orig_epoch_step
|
201 |
-
skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
|
202 |
-
else:
|
203 |
-
skipped_epoch = 0
|
204 |
-
|
205 |
-
for epoch in range(skipped_epoch, self.epochs):
|
206 |
-
self.model.train()
|
207 |
-
if exists(resumable_with_seed) and epoch == skipped_epoch:
|
208 |
-
progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
|
209 |
-
initial=skipped_batch, total=orig_epoch_step)
|
210 |
-
else:
|
211 |
-
progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
|
212 |
-
|
213 |
-
for batch in progress_bar:
|
214 |
-
with self.accelerator.accumulate(self.model):
|
215 |
-
text_inputs = batch['text']
|
216 |
-
mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
|
217 |
-
mel_lengths = batch["mel_lengths"]
|
218 |
-
|
219 |
-
# TODO. add duration predictor training
|
220 |
-
if self.duration_predictor is not None and self.accelerator.is_local_main_process:
|
221 |
-
dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
|
222 |
-
self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
|
223 |
-
|
224 |
-
loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
|
225 |
-
self.accelerator.backward(loss)
|
226 |
-
|
227 |
-
if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
|
228 |
-
self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
229 |
-
|
230 |
-
self.optimizer.step()
|
231 |
-
self.scheduler.step()
|
232 |
-
self.optimizer.zero_grad()
|
233 |
-
|
234 |
-
if self.is_main:
|
235 |
-
self.ema_model.update()
|
236 |
-
|
237 |
-
global_step += 1
|
238 |
-
|
239 |
-
if self.accelerator.is_local_main_process:
|
240 |
-
self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
|
241 |
-
|
242 |
-
progress_bar.set_postfix(step=str(global_step), loss=loss.item())
|
243 |
-
|
244 |
-
if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
|
245 |
-
self.save_checkpoint(global_step)
|
246 |
-
|
247 |
-
if global_step % self.last_per_steps == 0:
|
248 |
-
self.save_checkpoint(global_step, last=True)
|
249 |
-
|
250 |
-
self.accelerator.end_training()
|
|
|
|
|
|
|
|
|
|
|
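Note: the deleted Trainer builds its learning-rate schedule as a linear warmup from near zero up to the base LR, followed by a linear decay back down, stitched together with SequentialLR. A toy, self-contained sketch of just that schedule; the parameter, base LR, and step counts here are illustrative placeholders, not the values used in training:

```python
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR

# A dummy parameter just so the optimizer has something to manage.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = AdamW(params, lr=1e-4)

warmup_steps, total_steps = 100, 1000
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
decay = LinearLR(optimizer, start_factor=1.0, end_factor=1e-8, total_iters=total_steps - warmup_steps)
scheduler = SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_steps])

lrs = []
for _ in range(total_steps):
    optimizer.step()   # would normally follow loss.backward()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(f"peak lr ~ {max(lrs):.2e} around step {lrs.index(max(lrs))}")
```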
|
|
|
model/utils.py
DELETED
@@ -1,580 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
import os
|
4 |
-
import re
|
5 |
-
import math
|
6 |
-
import random
|
7 |
-
import string
|
8 |
-
from tqdm import tqdm
|
9 |
-
from collections import defaultdict
|
10 |
-
|
11 |
-
import matplotlib
|
12 |
-
matplotlib.use("Agg")
|
13 |
-
import matplotlib.pylab as plt
|
14 |
-
|
15 |
-
import torch
|
16 |
-
import torch.nn.functional as F
|
17 |
-
from torch.nn.utils.rnn import pad_sequence
|
18 |
-
import torchaudio
|
19 |
-
|
20 |
-
import einx
|
21 |
-
from einops import rearrange, reduce
|
22 |
-
|
23 |
-
import jieba
|
24 |
-
from pypinyin import lazy_pinyin, Style
|
25 |
-
|
26 |
-
from model.ecapa_tdnn import ECAPA_TDNN_SMALL
|
27 |
-
from model.modules import MelSpec
|
28 |
-
|
29 |
-
|
30 |
-
# seed everything
|
31 |
-
|
32 |
-
def seed_everything(seed = 0):
|
33 |
-
random.seed(seed)
|
34 |
-
os.environ['PYTHONHASHSEED'] = str(seed)
|
35 |
-
torch.manual_seed(seed)
|
36 |
-
torch.cuda.manual_seed(seed)
|
37 |
-
torch.cuda.manual_seed_all(seed)
|
38 |
-
torch.backends.cudnn.deterministic = True
|
39 |
-
torch.backends.cudnn.benchmark = False
|
40 |
-
|
41 |
-
# helpers
|
42 |
-
|
43 |
-
def exists(v):
|
44 |
-
return v is not None
|
45 |
-
|
46 |
-
def default(v, d):
|
47 |
-
return v if exists(v) else d
|
48 |
-
|
49 |
-
# tensor helpers
|
50 |
-
|
51 |
-
def lens_to_mask(
|
52 |
-
t: int['b'],
|
53 |
-
length: int | None = None
|
54 |
-
) -> bool['b n']:
|
55 |
-
|
56 |
-
if not exists(length):
|
57 |
-
length = t.amax()
|
58 |
-
|
59 |
-
seq = torch.arange(length, device = t.device)
|
60 |
-
return einx.less('n, b -> b n', seq, t)
|
61 |
-
|
62 |
-
def mask_from_start_end_indices(
|
63 |
-
seq_len: int['b'],
|
64 |
-
start: int['b'],
|
65 |
-
end: int['b']
|
66 |
-
):
|
67 |
-
max_seq_len = seq_len.max().item()
|
68 |
-
seq = torch.arange(max_seq_len, device = start.device).long()
|
69 |
-
return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
|
70 |
-
|
71 |
-
def mask_from_frac_lengths(
|
72 |
-
seq_len: int['b'],
|
73 |
-
frac_lengths: float['b']
|
74 |
-
):
|
75 |
-
lengths = (frac_lengths * seq_len).long()
|
76 |
-
max_start = seq_len - lengths
|
77 |
-
|
78 |
-
rand = torch.rand_like(frac_lengths)
|
79 |
-
start = (max_start * rand).long().clamp(min = 0)
|
80 |
-
end = start + lengths
|
81 |
-
|
82 |
-
return mask_from_start_end_indices(seq_len, start, end)
|
83 |
-
|
84 |
-
def maybe_masked_mean(
|
85 |
-
t: float['b n d'],
|
86 |
-
mask: bool['b n'] = None
|
87 |
-
) -> float['b d']:
|
88 |
-
|
89 |
-
if not exists(mask):
|
90 |
-
return t.mean(dim = 1)
|
91 |
-
|
92 |
-
t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
|
93 |
-
num = reduce(t, 'b n d -> b d', 'sum')
|
94 |
-
den = reduce(mask.float(), 'b n -> b', 'sum')
|
95 |
-
|
96 |
-
return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
|
97 |
-
|
98 |
-
|
99 |
-
# simple utf-8 tokenizer, since paper went character based
|
100 |
-
def list_str_to_tensor(
|
101 |
-
text: list[str],
|
102 |
-
padding_value = -1
|
103 |
-
) -> int['b nt']:
|
104 |
-
list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
|
105 |
-
text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
|
106 |
-
return text
|
107 |
-
|
108 |
-
# char tokenizer, based on custom dataset's extracted .txt file
|
109 |
-
def list_str_to_idx(
|
110 |
-
text: list[str] | list[list[str]],
|
111 |
-
vocab_char_map: dict[str, int], # {char: idx}
|
112 |
-
padding_value = -1
|
113 |
-
) -> int['b nt']:
|
114 |
-
list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
|
115 |
-
text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
|
116 |
-
return text
|
117 |
-
|
118 |
-
|
119 |
-
# Get tokenizer
|
120 |
-
|
121 |
-
def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
|
122 |
-
'''
|
123 |
-
tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
|
124 |
-
- "char" for char-wise tokenizer, need .txt vocab_file
|
125 |
-
- "byte" for utf-8 tokenizer
|
126 |
-
- "custom" if you're directly passing in a path to the vocab.txt you want to use
|
127 |
-
vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
|
128 |
-
- if use "char", derived from unfiltered character & symbol counts of custom dataset
|
129 |
-
- if use "byte", set to 256 (unicode byte range)
|
130 |
-
'''
|
131 |
-
if tokenizer in ["pinyin", "char"]:
|
132 |
-
with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
|
133 |
-
vocab_char_map = {}
|
134 |
-
for i, char in enumerate(f):
|
135 |
-
vocab_char_map[char[:-1]] = i
|
136 |
-
vocab_size = len(vocab_char_map)
|
137 |
-
assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
|
138 |
-
|
139 |
-
elif tokenizer == "byte":
|
140 |
-
vocab_char_map = None
|
141 |
-
vocab_size = 256
|
142 |
-
elif tokenizer == "custom":
|
143 |
-
with open (dataset_name, "r", encoding="utf-8") as f:
|
144 |
-
vocab_char_map = {}
|
145 |
-
for i, char in enumerate(f):
|
146 |
-
vocab_char_map[char[:-1]] = i
|
147 |
-
vocab_size = len(vocab_char_map)
|
148 |
-
|
149 |
-
return vocab_char_map, vocab_size
|
150 |
-
|
151 |
-
|
152 |
-
# convert char to pinyin
|
153 |
-
|
154 |
-
def convert_char_to_pinyin(text_list, polyphone = True):
|
155 |
-
final_text_list = []
|
156 |
-
god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
|
157 |
-
custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
|
158 |
-
for text in text_list:
|
159 |
-
char_list = []
|
160 |
-
text = text.translate(god_knows_why_en_testset_contains_zh_quote)
|
161 |
-
text = text.translate(custom_trans)
|
162 |
-
for seg in jieba.cut(text):
|
163 |
-
seg_byte_len = len(bytes(seg, 'UTF-8'))
|
164 |
-
if seg_byte_len == len(seg): # if pure alphabets and symbols
|
165 |
-
if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
|
166 |
-
char_list.append(" ")
|
167 |
-
char_list.extend(seg)
|
168 |
-
elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
|
169 |
-
seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
|
170 |
-
for c in seg:
|
171 |
-
if c not in "。,、;:?!《》【】—…":
|
172 |
-
char_list.append(" ")
|
173 |
-
char_list.append(c)
|
174 |
-
else: # if mixed chinese characters, alphabets and symbols
|
175 |
-
for c in seg:
|
176 |
-
if ord(c) < 256:
|
177 |
-
char_list.extend(c)
|
178 |
-
else:
|
179 |
-
if c not in "。,、;:?!《》【】—…":
|
180 |
-
char_list.append(" ")
|
181 |
-
char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
|
182 |
-
else: # if is zh punc
|
183 |
-
char_list.append(c)
|
184 |
-
final_text_list.append(char_list)
|
185 |
-
|
186 |
-
return final_text_list
|
187 |
-
|
188 |
-
|
189 |
-
# save spectrogram
|
190 |
-
def save_spectrogram(spectrogram, path):
|
191 |
-
plt.figure(figsize=(12, 4))
|
192 |
-
plt.imshow(spectrogram, origin='lower', aspect='auto')
|
193 |
-
plt.colorbar()
|
194 |
-
plt.savefig(path)
|
195 |
-
plt.close()
|
196 |
-
|
197 |
-
|
198 |
-
# seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
199 |
-
def get_seedtts_testset_metainfo(metalst):
|
200 |
-
f = open(metalst); lines = f.readlines(); f.close()
|
201 |
-
metainfo = []
|
202 |
-
for line in lines:
|
203 |
-
if len(line.strip().split('|')) == 5:
|
204 |
-
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
|
205 |
-
elif len(line.strip().split('|')) == 4:
|
206 |
-
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
|
207 |
-
gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
|
208 |
-
if not os.path.isabs(prompt_wav):
|
209 |
-
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
210 |
-
metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
|
211 |
-
return metainfo
|
212 |
-
|
213 |
-
|
214 |
-
# librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
|
215 |
-
def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
|
216 |
-
f = open(metalst); lines = f.readlines(); f.close()
|
217 |
-
metainfo = []
|
218 |
-
for line in lines:
|
219 |
-
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
|
220 |
-
|
221 |
-
# ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
222 |
-
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
|
223 |
-
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
|
224 |
-
|
225 |
-
# gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
|
226 |
-
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
|
227 |
-
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
|
228 |
-
|
229 |
-
metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
|
230 |
-
|
231 |
-
return metainfo
|
232 |
-
|
233 |
-
|
234 |
-
# padded to max length mel batch
|
235 |
-
def padded_mel_batch(ref_mels):
|
236 |
-
max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
|
237 |
-
padded_ref_mels = []
|
238 |
-
for mel in ref_mels:
|
239 |
-
padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
|
240 |
-
padded_ref_mels.append(padded_ref_mel)
|
241 |
-
padded_ref_mels = torch.stack(padded_ref_mels)
|
242 |
-
padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
|
243 |
-
return padded_ref_mels
|
244 |
-
|
245 |
-
|
246 |
-
# get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
|
247 |
-
|
248 |
-
def get_inference_prompt(
|
249 |
-
metainfo,
|
250 |
-
speed = 1., tokenizer = "pinyin", polyphone = True,
|
251 |
-
target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
|
252 |
-
use_truth_duration = False,
|
253 |
-
infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
|
254 |
-
):
|
255 |
-
prompts_all = []
|
256 |
-
|
257 |
-
min_tokens = min_secs * target_sample_rate // hop_length
|
258 |
-
max_tokens = max_secs * target_sample_rate // hop_length
|
259 |
-
|
260 |
-
batch_accum = [0] * num_buckets
|
261 |
-
utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
|
262 |
-
([[] for _ in range(num_buckets)] for _ in range(6))
|
263 |
-
|
264 |
-
mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
|
265 |
-
|
266 |
-
for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
|
267 |
-
|
268 |
-
# Audio
|
269 |
-
ref_audio, ref_sr = torchaudio.load(prompt_wav)
|
270 |
-
ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
|
271 |
-
if ref_rms < target_rms:
|
272 |
-
ref_audio = ref_audio * target_rms / ref_rms
|
273 |
-
assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
|
274 |
-
if ref_sr != target_sample_rate:
|
275 |
-
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
276 |
-
ref_audio = resampler(ref_audio)
|
277 |
-
|
278 |
-
# Text
|
279 |
-
if len(prompt_text[-1].encode('utf-8')) == 1:
|
280 |
-
prompt_text = prompt_text + " "
|
281 |
-
text = [prompt_text + gt_text]
|
282 |
-
if tokenizer == "pinyin":
|
283 |
-
text_list = convert_char_to_pinyin(text, polyphone = polyphone)
|
284 |
-
else:
|
285 |
-
text_list = text
|
286 |
-
|
287 |
-
# Duration, mel frame length
|
288 |
-
ref_mel_len = ref_audio.shape[-1] // hop_length
|
289 |
-
if use_truth_duration:
|
290 |
-
gt_audio, gt_sr = torchaudio.load(gt_wav)
|
291 |
-
if gt_sr != target_sample_rate:
|
292 |
-
resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
|
293 |
-
gt_audio = resampler(gt_audio)
|
294 |
-
total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
|
295 |
-
|
296 |
-
# # test vocoder resynthesis
|
297 |
-
# ref_audio = gt_audio
|
298 |
-
else:
|
299 |
-
zh_pause_punc = r"。,、;:?!"
|
300 |
-
ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
|
301 |
-
gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
|
302 |
-
total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
|
303 |
-
|
304 |
-
# to mel spectrogram
|
305 |
-
ref_mel = mel_spectrogram(ref_audio)
|
306 |
-
ref_mel = rearrange(ref_mel, '1 d n -> d n')
|
307 |
-
|
308 |
-
# deal with batch
|
309 |
-
assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
|
310 |
-
assert min_tokens <= total_mel_len <= max_tokens, \
|
311 |
-
f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
|
312 |
-
bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
|
313 |
-
|
314 |
-
utts[bucket_i].append(utt)
|
315 |
-
ref_rms_list[bucket_i].append(ref_rms)
|
316 |
-
ref_mels[bucket_i].append(ref_mel)
|
317 |
-
ref_mel_lens[bucket_i].append(ref_mel_len)
|
318 |
-
total_mel_lens[bucket_i].append(total_mel_len)
|
319 |
-
final_text_list[bucket_i].extend(text_list)
|
320 |
-
|
321 |
-
batch_accum[bucket_i] += total_mel_len
|
322 |
-
|
323 |
-
if batch_accum[bucket_i] >= infer_batch_size:
|
324 |
-
# print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
|
325 |
-
prompts_all.append((
|
326 |
-
utts[bucket_i],
|
327 |
-
ref_rms_list[bucket_i],
|
328 |
-
padded_mel_batch(ref_mels[bucket_i]),
|
329 |
-
ref_mel_lens[bucket_i],
|
330 |
-
total_mel_lens[bucket_i],
|
331 |
-
final_text_list[bucket_i]
|
332 |
-
))
|
333 |
-
batch_accum[bucket_i] = 0
|
334 |
-
utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
|
335 |
-
|
336 |
-
# add residual
|
337 |
-
for bucket_i, bucket_frames in enumerate(batch_accum):
|
338 |
-
if bucket_frames > 0:
|
339 |
-
prompts_all.append((
|
340 |
-
utts[bucket_i],
|
341 |
-
ref_rms_list[bucket_i],
|
342 |
-
padded_mel_batch(ref_mels[bucket_i]),
|
343 |
-
ref_mel_lens[bucket_i],
|
344 |
-
total_mel_lens[bucket_i],
|
345 |
-
final_text_list[bucket_i]
|
346 |
-
))
|
347 |
-
# not only leave easy work for last workers
|
348 |
-
random.seed(666)
|
349 |
-
random.shuffle(prompts_all)
|
350 |
-
|
351 |
-
return prompts_all
|
352 |
-
|
353 |
-
|
354 |
-
# get wav_res_ref_text of seed-tts test metalst
|
355 |
-
# https://github.com/BytedanceSpeech/seed-tts-eval
|
356 |
-
|
357 |
-
def get_seed_tts_test(metalst, gen_wav_dir, gpus):
|
358 |
-
f = open(metalst)
|
359 |
-
lines = f.readlines()
|
360 |
-
f.close()
|
361 |
-
|
362 |
-
test_set_ = []
|
363 |
-
for line in tqdm(lines):
|
364 |
-
if len(line.strip().split('|')) == 5:
|
365 |
-
utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
|
366 |
-
elif len(line.strip().split('|')) == 4:
|
367 |
-
utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
|
368 |
-
|
369 |
-
if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
|
370 |
-
continue
|
371 |
-
gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
|
372 |
-
if not os.path.isabs(prompt_wav):
|
373 |
-
prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
|
374 |
-
|
375 |
-
test_set_.append((gen_wav, prompt_wav, gt_text))
|
376 |
-
|
377 |
-
num_jobs = len(gpus)
|
378 |
-
if num_jobs == 1:
|
379 |
-
return [(gpus[0], test_set_)]
|
380 |
-
|
381 |
-
wav_per_job = len(test_set_) // num_jobs + 1
|
382 |
-
test_set = []
|
383 |
-
for i in range(num_jobs):
|
384 |
-
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
|
385 |
-
|
386 |
-
return test_set
|
387 |
-
|
388 |
-
|
389 |
-
# get librispeech test-clean cross sentence test
|
390 |
-
|
391 |
-
def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
|
392 |
-
f = open(metalst)
|
393 |
-
lines = f.readlines()
|
394 |
-
f.close()
|
395 |
-
|
396 |
-
test_set_ = []
|
397 |
-
for line in tqdm(lines):
|
398 |
-
ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
|
399 |
-
|
400 |
-
if eval_ground_truth:
|
401 |
-
gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
|
402 |
-
gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
|
403 |
-
else:
|
404 |
-
if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
|
405 |
-
raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
|
406 |
-
gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
|
407 |
-
|
408 |
-
ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
|
409 |
-
ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
|
410 |
-
|
411 |
-
test_set_.append((gen_wav, ref_wav, gen_txt))
|
412 |
-
|
413 |
-
num_jobs = len(gpus)
|
414 |
-
if num_jobs == 1:
|
415 |
-
return [(gpus[0], test_set_)]
|
416 |
-
|
417 |
-
wav_per_job = len(test_set_) // num_jobs + 1
|
418 |
-
test_set = []
|
419 |
-
for i in range(num_jobs):
|
420 |
-
test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
|
421 |
-
|
422 |
-
return test_set
|
423 |
-
|
424 |
-
|
425 |
-
# load asr model
|
426 |
-
|
427 |
-
def load_asr_model(lang, ckpt_dir = ""):
|
428 |
-
if lang == "zh":
|
429 |
-
from funasr import AutoModel
|
430 |
-
model = AutoModel(
|
431 |
-
model = os.path.join(ckpt_dir, "paraformer-zh"),
|
432 |
-
# vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
|
433 |
-
# punc_model = os.path.join(ckpt_dir, "ct-punc"),
|
434 |
-
# spk_model = os.path.join(ckpt_dir, "cam++"),
|
435 |
-
disable_update=True,
|
436 |
-
) # following seed-tts setting
|
437 |
-
elif lang == "en":
|
438 |
-
from faster_whisper import WhisperModel
|
439 |
-
model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
|
440 |
-
model = WhisperModel(model_size, device="cuda", compute_type="float16")
|
441 |
-
return model
|
442 |
-
|
443 |
-
|
444 |
-
# WER Evaluation, the way Seed-TTS does
|
445 |
-
|
446 |
-
def run_asr_wer(args):
|
447 |
-
rank, lang, test_set, ckpt_dir = args
|
448 |
-
|
449 |
-
if lang == "zh":
|
450 |
-
import zhconv
|
451 |
-
torch.cuda.set_device(rank)
|
452 |
-
elif lang == "en":
|
453 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
|
454 |
-
else:
|
455 |
-
raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
|
456 |
-
|
457 |
-
asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
|
458 |
-
|
459 |
-
from zhon.hanzi import punctuation
|
460 |
-
punctuation_all = punctuation + string.punctuation
|
461 |
-
wers = []
|
462 |
-
|
463 |
-
from jiwer import compute_measures
|
464 |
-
for gen_wav, prompt_wav, truth in tqdm(test_set):
|
465 |
-
if lang == "zh":
|
466 |
-
res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
|
467 |
-
hypo = res[0]["text"]
|
468 |
-
hypo = zhconv.convert(hypo, 'zh-cn')
|
469 |
-
elif lang == "en":
|
470 |
-
segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
|
471 |
-
hypo = ''
|
472 |
-
for segment in segments:
|
473 |
-
hypo = hypo + ' ' + segment.text
|
474 |
-
|
475 |
-
# raw_truth = truth
|
476 |
-
# raw_hypo = hypo
|
477 |
-
|
478 |
-
for x in punctuation_all:
|
479 |
-
truth = truth.replace(x, '')
|
480 |
-
hypo = hypo.replace(x, '')
|
481 |
-
|
482 |
-
truth = truth.replace(' ', ' ')
|
483 |
-
hypo = hypo.replace(' ', ' ')
|
484 |
-
|
485 |
-
if lang == "zh":
|
486 |
-
truth = " ".join([x for x in truth])
|
487 |
-
hypo = " ".join([x for x in hypo])
|
488 |
-
elif lang == "en":
|
489 |
-
truth = truth.lower()
|
490 |
-
hypo = hypo.lower()
|
491 |
-
|
492 |
-
measures = compute_measures(truth, hypo)
|
493 |
-
wer = measures["wer"]
|
494 |
-
|
495 |
-
# ref_list = truth.split(" ")
|
496 |
-
# subs = measures["substitutions"] / len(ref_list)
|
497 |
-
# dele = measures["deletions"] / len(ref_list)
|
498 |
-
# inse = measures["insertions"] / len(ref_list)
|
499 |
-
|
500 |
-
wers.append(wer)
|
501 |
-
|
502 |
-
return wers
|
503 |
-
|
504 |
-
|
505 |
-
# SIM Evaluation
|
506 |
-
|
507 |
-
def run_sim(args):
    rank, test_set, ckpt_dir = args
    device = f"cuda:{rank}"

    # WavLM-large based ECAPA-TDNN speaker verification model
    model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
    state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict['model'], strict=False)

    use_gpu = True if torch.cuda.is_available() else False
    if use_gpu:
        model = model.cuda(device)
    model.eval()

    sim_list = []
    for wav1, wav2, truth in tqdm(test_set):

        wav1, sr1 = torchaudio.load(wav1)
        wav2, sr2 = torchaudio.load(wav2)

        # resample both utterances to 16 kHz before embedding
        resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
        resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
        wav1 = resample1(wav1)
        wav2 = resample2(wav2)

        if use_gpu:
            wav1 = wav1.cuda(device)
            wav2 = wav2.cuda(device)
        with torch.no_grad():
            emb1 = model(wav1)
            emb2 = model(wav2)

        # speaker similarity = cosine similarity of the two embeddings
        sim = F.cosine_similarity(emb1, emb2)[0].item()
        # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
        sim_list.append(sim)

    return sim_list
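
Similarly, a hedged sketch of a single-shard run_sim call; the wav paths and the speaker-verification checkpoint path are placeholders, not values from this repository.

# Illustrative call (not part of the original model/utils.py)
sim_test_set = [("gen_0.wav", "prompt_0.wav", "hello world")]      # placeholder paths/text
sims = run_sim((0, sim_test_set, "ckpts/wavlm_verification.pth"))  # hypothetical checkpoint path
print(f"mean SIM: {sum(sims) / len(sims):.4f}")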


# filter func for dirty data with many repetitions

def repetition_found(text, length = 2, tolerance = 10):
    # slide a window of `length` characters over the text, count each pattern,
    # and flag the sample if any pattern repeats more than `tolerance` times
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern = text[i:i + length]
        pattern_count[pattern] += 1
    for pattern, count in pattern_count.items():
        if count > tolerance:
            return True
    return False
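
A quick, hedged illustration of the filter on made-up strings (not taken from any dataset used here):

# Illustrative checks (not part of the original model/utils.py)
assert repetition_found("ha" * 13)                       # the 2-gram "ha" occurs 13 times (> tolerance of 10)
assert not repetition_found("a normal, varied sentence") # no 2-gram repeats often enough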


# load model checkpoint for inference

def load_checkpoint(model, ckpt_path, device, use_ema = True):
    from ema_pytorch import EMA

    ckpt_type = ckpt_path.split(".")[-1]
    if ckpt_type == "safetensors":
        from safetensors.torch import load_file
        checkpoint = load_file(ckpt_path, device=device)
    else:
        checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)

    if use_ema:
        # load the stored EMA weights and copy them into the online model
        ema_model = EMA(model, include_online_model = False).to(device)
        if ckpt_type == "safetensors":
            ema_model.load_state_dict(checkpoint)
        else:
            ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        ema_model.copy_params_from_ema_to_model()
    else:
        model.load_state_dict(checkpoint['model_state_dict'])

    return model
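
Finally, a hedged sketch of load_checkpoint in use; the stand-in module and the checkpoint path are hypothetical, standing in for the real TTS model and its trained weights.

# Hypothetical usage (not part of the original model/utils.py)
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
net = nn.Linear(8, 8)  # stand-in module; in practice this is the constructed TTS model
# placeholder path; with use_ema=True the stored EMA weights are copied into `net`
net = load_checkpoint(net, "ckpts/model_last.pt", device, use_ema=True)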
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|