Staticaliza commited on
Commit
76f383f
1 Parent(s): 4c78f4a

Upload 7 files

Browse files
model/model___init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from model.cfm import CFM
2
+
3
+ from model.backbones.unett import UNetT
4
+ from model.backbones.dit import DiT
5
+ from model.backbones.mmdit import MMDiT
6
+
7
+ from model.trainer import Trainer
model/model_cfm.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Callable
12
+ from random import random
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+
19
+ from torchdiffeq import odeint
20
+
21
+ from einops import rearrange
22
+
23
+ from model.modules import MelSpec
24
+
25
+ from model.utils import (
26
+ default, exists,
27
+ list_str_to_idx, list_str_to_tensor,
28
+ lens_to_mask, mask_from_frac_lengths,
29
+ )
30
+
31
+
32
+ class CFM(nn.Module):
33
+ def __init__(
34
+ self,
35
+ transformer: nn.Module,
36
+ sigma = 0.,
37
+ odeint_kwargs: dict = dict(
38
+ # atol = 1e-5,
39
+ # rtol = 1e-5,
40
+ method = 'euler' # 'midpoint'
41
+ ),
42
+ audio_drop_prob = 0.3,
43
+ cond_drop_prob = 0.2,
44
+ num_channels = None,
45
+ mel_spec_module: nn.Module | None = None,
46
+ mel_spec_kwargs: dict = dict(),
47
+ frac_lengths_mask: tuple[float, float] = (0.7, 1.),
48
+ vocab_char_map: dict[str: int] | None = None
49
+ ):
50
+ super().__init__()
51
+
52
+ self.frac_lengths_mask = frac_lengths_mask
53
+
54
+ # mel spec
55
+ self.mel_spec = default(mel_spec_module, MelSpec(**mel_spec_kwargs))
56
+ num_channels = default(num_channels, self.mel_spec.n_mel_channels)
57
+ self.num_channels = num_channels
58
+
59
+ # classifier-free guidance
60
+ self.audio_drop_prob = audio_drop_prob
61
+ self.cond_drop_prob = cond_drop_prob
62
+
63
+ # transformer
64
+ self.transformer = transformer
65
+ dim = transformer.dim
66
+ self.dim = dim
67
+
68
+ # conditional flow related
69
+ self.sigma = sigma
70
+
71
+ # sampling related
72
+ self.odeint_kwargs = odeint_kwargs
73
+
74
+ # vocab map for tokenization
75
+ self.vocab_char_map = vocab_char_map
76
+
77
+ @property
78
+ def device(self):
79
+ return next(self.parameters()).device
80
+
81
+ @torch.no_grad()
82
+ def sample(
83
+ self,
84
+ cond: float['b n d'] | float['b nw'],
85
+ text: int['b nt'] | list[str],
86
+ duration: int | int['b'],
87
+ *,
88
+ lens: int['b'] | None = None,
89
+ steps = 32,
90
+ cfg_strength = 1.,
91
+ sway_sampling_coef = None,
92
+ seed: int | None = None,
93
+ max_duration = 4096,
94
+ vocoder: Callable[[float['b d n']], float['b nw']] | None = None,
95
+ no_ref_audio = False,
96
+ duplicate_test = False,
97
+ t_inter = 0.1,
98
+ edit_mask = None,
99
+ ):
100
+ self.eval()
101
+
102
+ # raw wave
103
+
104
+ if cond.ndim == 2:
105
+ cond = self.mel_spec(cond)
106
+ cond = rearrange(cond, 'b d n -> b n d')
107
+ assert cond.shape[-1] == self.num_channels
108
+
109
+ batch, cond_seq_len, device = *cond.shape[:2], cond.device
110
+ if not exists(lens):
111
+ lens = torch.full((batch,), cond_seq_len, device = device, dtype = torch.long)
112
+
113
+ # text
114
+
115
+ if isinstance(text, list):
116
+ if exists(self.vocab_char_map):
117
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
118
+ else:
119
+ text = list_str_to_tensor(text).to(device)
120
+ assert text.shape[0] == batch
121
+
122
+ if exists(text):
123
+ text_lens = (text != -1).sum(dim = -1)
124
+ lens = torch.maximum(text_lens, lens) # make sure lengths are at least those of the text characters
125
+
126
+ # duration
127
+
128
+ cond_mask = lens_to_mask(lens)
129
+ if edit_mask is not None:
130
+ cond_mask = cond_mask & edit_mask
131
+
132
+ if isinstance(duration, int):
133
+ duration = torch.full((batch,), duration, device = device, dtype = torch.long)
134
+
135
+ duration = torch.maximum(lens + 1, duration) # just add one token so something is generated
136
+ duration = duration.clamp(max = max_duration)
137
+ max_duration = duration.amax()
138
+
139
+ # duplicate test corner for inner time step oberservation
140
+ if duplicate_test:
141
+ test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2*cond_seq_len), value = 0.)
142
+
143
+ cond = F.pad(cond, (0, 0, 0, max_duration - cond_seq_len), value = 0.)
144
+ cond_mask = F.pad(cond_mask, (0, max_duration - cond_mask.shape[-1]), value = False)
145
+ cond_mask = rearrange(cond_mask, '... -> ... 1')
146
+ step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond)) # allow direct control (cut cond audio) with lens passed in
147
+
148
+ if batch > 1:
149
+ mask = lens_to_mask(duration)
150
+ else: # save memory and speed up, as single inference need no mask currently
151
+ mask = None
152
+
153
+ # test for no ref audio
154
+ if no_ref_audio:
155
+ cond = torch.zeros_like(cond)
156
+
157
+ # neural ode
158
+
159
+ def fn(t, x):
160
+ # at each step, conditioning is fixed
161
+ # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
162
+
163
+ # predict flow
164
+ pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = False, drop_text = False)
165
+ if cfg_strength < 1e-5:
166
+ return pred
167
+
168
+ null_pred = self.transformer(x = x, cond = step_cond, text = text, time = t, mask = mask, drop_audio_cond = True, drop_text = True)
169
+ return pred + (pred - null_pred) * cfg_strength
170
+
171
+ # noise input
172
+ # to make sure batch inference result is same with different batch size, and for sure single inference
173
+ # still some difference maybe due to convolutional layers
174
+ y0 = []
175
+ for dur in duration:
176
+ if exists(seed):
177
+ torch.manual_seed(seed)
178
+ y0.append(torch.randn(dur, self.num_channels, device = self.device))
179
+ y0 = pad_sequence(y0, padding_value = 0, batch_first = True)
180
+
181
+ t_start = 0
182
+
183
+ # duplicate test corner for inner time step oberservation
184
+ if duplicate_test:
185
+ t_start = t_inter
186
+ y0 = (1 - t_start) * y0 + t_start * test_cond
187
+ steps = int(steps * (1 - t_start))
188
+
189
+ t = torch.linspace(t_start, 1, steps, device = self.device)
190
+ if sway_sampling_coef is not None:
191
+ t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
192
+
193
+ trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
194
+
195
+ sampled = trajectory[-1]
196
+ out = sampled
197
+ out = torch.where(cond_mask, cond, out)
198
+
199
+ if exists(vocoder):
200
+ out = rearrange(out, 'b n d -> b d n')
201
+ out = vocoder(out)
202
+
203
+ return out, trajectory
204
+
205
+ def forward(
206
+ self,
207
+ inp: float['b n d'] | float['b nw'], # mel or raw wave
208
+ text: int['b nt'] | list[str],
209
+ *,
210
+ lens: int['b'] | None = None,
211
+ noise_scheduler: str | None = None,
212
+ ):
213
+ # handle raw wave
214
+ if inp.ndim == 2:
215
+ inp = self.mel_spec(inp)
216
+ inp = rearrange(inp, 'b d n -> b n d')
217
+ assert inp.shape[-1] == self.num_channels
218
+
219
+ batch, seq_len, dtype, device, σ1 = *inp.shape[:2], inp.dtype, self.device, self.sigma
220
+
221
+ # handle text as string
222
+ if isinstance(text, list):
223
+ if exists(self.vocab_char_map):
224
+ text = list_str_to_idx(text, self.vocab_char_map).to(device)
225
+ else:
226
+ text = list_str_to_tensor(text).to(device)
227
+ assert text.shape[0] == batch
228
+
229
+ # lens and mask
230
+ if not exists(lens):
231
+ lens = torch.full((batch,), seq_len, device = device)
232
+
233
+ mask = lens_to_mask(lens, length = seq_len) # useless here, as collate_fn will pad to max length in batch
234
+
235
+ # get a random span to mask out for training conditionally
236
+ frac_lengths = torch.zeros((batch,), device = self.device).float().uniform_(*self.frac_lengths_mask)
237
+ rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
238
+
239
+ if exists(mask):
240
+ rand_span_mask &= mask
241
+
242
+ # mel is x1
243
+ x1 = inp
244
+
245
+ # x0 is gaussian noise
246
+ x0 = torch.randn_like(x1)
247
+
248
+ # time step
249
+ time = torch.rand((batch,), dtype = dtype, device = self.device)
250
+ # TODO. noise_scheduler
251
+
252
+ # sample xt (φ_t(x) in the paper)
253
+ t = rearrange(time, 'b -> b 1 1')
254
+ φ = (1 - t) * x0 + t * x1
255
+ flow = x1 - x0
256
+
257
+ # only predict what is within the random mask span for infilling
258
+ cond = torch.where(
259
+ rand_span_mask[..., None],
260
+ torch.zeros_like(x1), x1
261
+ )
262
+
263
+ # transformer and cfg training with a drop rate
264
+ drop_audio_cond = random() < self.audio_drop_prob # p_drop in voicebox paper
265
+ if random() < self.cond_drop_prob: # p_uncond in voicebox paper
266
+ drop_audio_cond = True
267
+ drop_text = True
268
+ else:
269
+ drop_text = False
270
+
271
+ # if want rigourously mask out padding, record in collate_fn in dataset.py, and pass in here
272
+ # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
273
+ pred = self.transformer(x = φ, cond = cond, text = text, time = time, drop_audio_cond = drop_audio_cond, drop_text = drop_text)
274
+
275
+ # flow matching loss
276
+ loss = F.mse_loss(pred, flow, reduction = 'none')
277
+ loss = loss[rand_span_mask]
278
+
279
+ return loss.mean(), cond, pred
model/model_dataset.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import Dataset, Sampler
8
+ import torchaudio
9
+ from datasets import load_dataset, load_from_disk
10
+ from datasets import Dataset as Dataset_
11
+
12
+ from einops import rearrange
13
+
14
+ from model.modules import MelSpec
15
+
16
+
17
+ class HFDataset(Dataset):
18
+ def __init__(
19
+ self,
20
+ hf_dataset: Dataset,
21
+ target_sample_rate = 24_000,
22
+ n_mel_channels = 100,
23
+ hop_length = 256,
24
+ ):
25
+ self.data = hf_dataset
26
+ self.target_sample_rate = target_sample_rate
27
+ self.hop_length = hop_length
28
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
29
+
30
+ def get_frame_len(self, index):
31
+ row = self.data[index]
32
+ audio = row['audio']['array']
33
+ sample_rate = row['audio']['sampling_rate']
34
+ return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
35
+
36
+ def __len__(self):
37
+ return len(self.data)
38
+
39
+ def __getitem__(self, index):
40
+ row = self.data[index]
41
+ audio = row['audio']['array']
42
+
43
+ # logger.info(f"Audio shape: {audio.shape}")
44
+
45
+ sample_rate = row['audio']['sampling_rate']
46
+ duration = audio.shape[-1] / sample_rate
47
+
48
+ if duration > 30 or duration < 0.3:
49
+ return self.__getitem__((index + 1) % len(self.data))
50
+
51
+ audio_tensor = torch.from_numpy(audio).float()
52
+
53
+ if sample_rate != self.target_sample_rate:
54
+ resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
55
+ audio_tensor = resampler(audio_tensor)
56
+
57
+ audio_tensor = rearrange(audio_tensor, 't -> 1 t')
58
+
59
+ mel_spec = self.mel_spectrogram(audio_tensor)
60
+
61
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
62
+
63
+ text = row['text']
64
+
65
+ return dict(
66
+ mel_spec = mel_spec,
67
+ text = text,
68
+ )
69
+
70
+
71
+ class CustomDataset(Dataset):
72
+ def __init__(
73
+ self,
74
+ custom_dataset: Dataset,
75
+ durations = None,
76
+ target_sample_rate = 24_000,
77
+ hop_length = 256,
78
+ n_mel_channels = 100,
79
+ preprocessed_mel = False,
80
+ ):
81
+ self.data = custom_dataset
82
+ self.durations = durations
83
+ self.target_sample_rate = target_sample_rate
84
+ self.hop_length = hop_length
85
+ self.preprocessed_mel = preprocessed_mel
86
+ if not preprocessed_mel:
87
+ self.mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, hop_length=hop_length, n_mel_channels=n_mel_channels)
88
+
89
+ def get_frame_len(self, index):
90
+ if self.durations is not None: # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
91
+ return self.durations[index] * self.target_sample_rate / self.hop_length
92
+ return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
93
+
94
+ def __len__(self):
95
+ return len(self.data)
96
+
97
+ def __getitem__(self, index):
98
+ row = self.data[index]
99
+ audio_path = row["audio_path"]
100
+ text = row["text"]
101
+ duration = row["duration"]
102
+
103
+ if self.preprocessed_mel:
104
+ mel_spec = torch.tensor(row["mel_spec"])
105
+
106
+ else:
107
+ audio, source_sample_rate = torchaudio.load(audio_path)
108
+
109
+ if duration > 30 or duration < 0.3:
110
+ return self.__getitem__((index + 1) % len(self.data))
111
+
112
+ if source_sample_rate != self.target_sample_rate:
113
+ resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
114
+ audio = resampler(audio)
115
+
116
+ mel_spec = self.mel_spectrogram(audio)
117
+ mel_spec = rearrange(mel_spec, '1 d t -> d t')
118
+
119
+ return dict(
120
+ mel_spec = mel_spec,
121
+ text = text,
122
+ )
123
+
124
+
125
+ # Dynamic Batch Sampler
126
+
127
+ class DynamicBatchSampler(Sampler[list[int]]):
128
+ """ Extension of Sampler that will do the following:
129
+ 1. Change the batch size (essentially number of sequences)
130
+ in a batch to ensure that the total number of frames are less
131
+ than a certain threshold.
132
+ 2. Make sure the padding efficiency in the batch is high.
133
+ """
134
+
135
+ def __init__(self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False):
136
+ self.sampler = sampler
137
+ self.frames_threshold = frames_threshold
138
+ self.max_samples = max_samples
139
+
140
+ indices, batches = [], []
141
+ data_source = self.sampler.data_source
142
+
143
+ for idx in tqdm(self.sampler, desc=f"Sorting with sampler... if slow, check whether dataset is provided with duration"):
144
+ indices.append((idx, data_source.get_frame_len(idx)))
145
+ indices.sort(key=lambda elem : elem[1])
146
+
147
+ batch = []
148
+ batch_frames = 0
149
+ for idx, frame_len in tqdm(indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"):
150
+ if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
151
+ batch.append(idx)
152
+ batch_frames += frame_len
153
+ else:
154
+ if len(batch) > 0:
155
+ batches.append(batch)
156
+ if frame_len <= self.frames_threshold:
157
+ batch = [idx]
158
+ batch_frames = frame_len
159
+ else:
160
+ batch = []
161
+ batch_frames = 0
162
+
163
+ if not drop_last and len(batch) > 0:
164
+ batches.append(batch)
165
+
166
+ del indices
167
+
168
+ # if want to have different batches between epochs, may just set a seed and log it in ckpt
169
+ # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
170
+ # e.g. for epoch n, use (random_seed + n)
171
+ random.seed(random_seed)
172
+ random.shuffle(batches)
173
+
174
+ self.batches = batches
175
+
176
+ def __iter__(self):
177
+ return iter(self.batches)
178
+
179
+ def __len__(self):
180
+ return len(self.batches)
181
+
182
+
183
+ # Load dataset
184
+
185
+ def load_dataset(
186
+ dataset_name: str,
187
+ tokenizer: str = "pinyin",
188
+ dataset_type: str = "CustomDataset",
189
+ audio_type: str = "raw",
190
+ mel_spec_kwargs: dict = dict()
191
+ ) -> CustomDataset | HFDataset:
192
+ '''
193
+ dataset_type - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
194
+ - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
195
+ '''
196
+
197
+ print("Loading dataset ...")
198
+
199
+ if dataset_type == "CustomDataset":
200
+ if audio_type == "raw":
201
+ try:
202
+ train_dataset = load_from_disk(f"data/{dataset_name}_{tokenizer}/raw")
203
+ except:
204
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/raw.arrow")
205
+ preprocessed_mel = False
206
+ elif audio_type == "mel":
207
+ train_dataset = Dataset_.from_file(f"data/{dataset_name}_{tokenizer}/mel.arrow")
208
+ preprocessed_mel = True
209
+ with open(f"data/{dataset_name}_{tokenizer}/duration.json", 'r', encoding='utf-8') as f:
210
+ data_dict = json.load(f)
211
+ durations = data_dict["duration"]
212
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
213
+
214
+ elif dataset_type == "CustomDatasetPath":
215
+ try:
216
+ train_dataset = load_from_disk(f"{dataset_name}/raw")
217
+ except:
218
+ train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
219
+
220
+ with open(f"{dataset_name}/duration.json", 'r', encoding='utf-8') as f:
221
+ data_dict = json.load(f)
222
+ durations = data_dict["duration"]
223
+ train_dataset = CustomDataset(train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs)
224
+
225
+ elif dataset_type == "HFDataset":
226
+ print("Should manually modify the path of huggingface dataset to your need.\n" +
227
+ "May also the corresponding script cuz different dataset may have different format.")
228
+ pre, post = dataset_name.split("_")
229
+ train_dataset = HFDataset(load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir="./data"),)
230
+
231
+ return train_dataset
232
+
233
+
234
+ # collation
235
+
236
+ def collate_fn(batch):
237
+ mel_specs = [item['mel_spec'].squeeze(0) for item in batch]
238
+ mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
239
+ max_mel_length = mel_lengths.amax()
240
+
241
+ padded_mel_specs = []
242
+ for spec in mel_specs: # TODO. maybe records mask for attention here
243
+ padding = (0, max_mel_length - spec.size(-1))
244
+ padded_spec = F.pad(spec, padding, value = 0)
245
+ padded_mel_specs.append(padded_spec)
246
+
247
+ mel_specs = torch.stack(padded_mel_specs)
248
+
249
+ text = [item['text'] for item in batch]
250
+ text_lengths = torch.LongTensor([len(item) for item in text])
251
+
252
+ return dict(
253
+ mel = mel_specs,
254
+ mel_lengths = mel_lengths,
255
+ text = text,
256
+ text_lengths = text_lengths,
257
+ )
model/model_ecapa_tdnn.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # just for speaker similarity evaluation, third-party code
2
+
3
+ # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
4
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ ''' Res2Conv1d + BatchNorm1d + ReLU
13
+ '''
14
+
15
+ class Res2Conv1dReluBn(nn.Module):
16
+ '''
17
+ in_channels == out_channels == channels
18
+ '''
19
+
20
+ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
21
+ super().__init__()
22
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
23
+ self.scale = scale
24
+ self.width = channels // scale
25
+ self.nums = scale if scale == 1 else scale - 1
26
+
27
+ self.convs = []
28
+ self.bns = []
29
+ for i in range(self.nums):
30
+ self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
31
+ self.bns.append(nn.BatchNorm1d(self.width))
32
+ self.convs = nn.ModuleList(self.convs)
33
+ self.bns = nn.ModuleList(self.bns)
34
+
35
+ def forward(self, x):
36
+ out = []
37
+ spx = torch.split(x, self.width, 1)
38
+ for i in range(self.nums):
39
+ if i == 0:
40
+ sp = spx[i]
41
+ else:
42
+ sp = sp + spx[i]
43
+ # Order: conv -> relu -> bn
44
+ sp = self.convs[i](sp)
45
+ sp = self.bns[i](F.relu(sp))
46
+ out.append(sp)
47
+ if self.scale != 1:
48
+ out.append(spx[self.nums])
49
+ out = torch.cat(out, dim=1)
50
+
51
+ return out
52
+
53
+
54
+ ''' Conv1d + BatchNorm1d + ReLU
55
+ '''
56
+
57
+ class Conv1dReluBn(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
59
+ super().__init__()
60
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
61
+ self.bn = nn.BatchNorm1d(out_channels)
62
+
63
+ def forward(self, x):
64
+ return self.bn(F.relu(self.conv(x)))
65
+
66
+
67
+ ''' The SE connection of 1D case.
68
+ '''
69
+
70
+ class SE_Connect(nn.Module):
71
+ def __init__(self, channels, se_bottleneck_dim=128):
72
+ super().__init__()
73
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
74
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
75
+
76
+ def forward(self, x):
77
+ out = x.mean(dim=2)
78
+ out = F.relu(self.linear1(out))
79
+ out = torch.sigmoid(self.linear2(out))
80
+ out = x * out.unsqueeze(2)
81
+
82
+ return out
83
+
84
+
85
+ ''' SE-Res2Block of the ECAPA-TDNN architecture.
86
+ '''
87
+
88
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
89
+ # return nn.Sequential(
90
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
91
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
92
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
93
+ # SE_Connect(channels)
94
+ # )
95
+
96
+ class SE_Res2Block(nn.Module):
97
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
98
+ super().__init__()
99
+ self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
100
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
101
+ self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
102
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
103
+
104
+ self.shortcut = None
105
+ if in_channels != out_channels:
106
+ self.shortcut = nn.Conv1d(
107
+ in_channels=in_channels,
108
+ out_channels=out_channels,
109
+ kernel_size=1,
110
+ )
111
+
112
+ def forward(self, x):
113
+ residual = x
114
+ if self.shortcut:
115
+ residual = self.shortcut(x)
116
+
117
+ x = self.Conv1dReluBn1(x)
118
+ x = self.Res2Conv1dReluBn(x)
119
+ x = self.Conv1dReluBn2(x)
120
+ x = self.SE_Connect(x)
121
+
122
+ return x + residual
123
+
124
+
125
+ ''' Attentive weighted mean and standard deviation pooling.
126
+ '''
127
+
128
+ class AttentiveStatsPool(nn.Module):
129
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
130
+ super().__init__()
131
+ self.global_context_att = global_context_att
132
+
133
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
134
+ if global_context_att:
135
+ self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
136
+ else:
137
+ self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
138
+ self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
139
+
140
+ def forward(self, x):
141
+
142
+ if self.global_context_att:
143
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
144
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
145
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
146
+ else:
147
+ x_in = x
148
+
149
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
150
+ alpha = torch.tanh(self.linear1(x_in))
151
+ # alpha = F.relu(self.linear1(x_in))
152
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
153
+ mean = torch.sum(alpha * x, dim=2)
154
+ residuals = torch.sum(alpha * (x ** 2), dim=2) - mean ** 2
155
+ std = torch.sqrt(residuals.clamp(min=1e-9))
156
+ return torch.cat([mean, std], dim=1)
157
+
158
+
159
+ class ECAPA_TDNN(nn.Module):
160
+ def __init__(self, feat_dim=80, channels=512, emb_dim=192, global_context_att=False,
161
+ feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
162
+ super().__init__()
163
+
164
+ self.feat_type = feat_type
165
+ self.feature_selection = feature_selection
166
+ self.update_extract = update_extract
167
+ self.sr = sr
168
+
169
+ torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
170
+ try:
171
+ local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
172
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source='local', config_path=config_path)
173
+ except:
174
+ self.feature_extract = torch.hub.load('s3prl/s3prl', feat_type)
175
+
176
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"):
177
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
178
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"):
179
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
180
+
181
+ self.feat_num = self.get_feat_num()
182
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
183
+
184
+ if feat_type != 'fbank' and feat_type != 'mfcc':
185
+ freeze_list = ['final_proj', 'label_embs_concat', 'mask_emb', 'project_q', 'quantizer']
186
+ for name, param in self.feature_extract.named_parameters():
187
+ for freeze_val in freeze_list:
188
+ if freeze_val in name:
189
+ param.requires_grad = False
190
+ break
191
+
192
+ if not self.update_extract:
193
+ for param in self.feature_extract.parameters():
194
+ param.requires_grad = False
195
+
196
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
197
+ # self.channels = [channels] * 4 + [channels * 3]
198
+ self.channels = [channels] * 4 + [1536]
199
+
200
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
201
+ self.layer2 = SE_Res2Block(self.channels[0], self.channels[1], kernel_size=3, stride=1, padding=2, dilation=2, scale=8, se_bottleneck_dim=128)
202
+ self.layer3 = SE_Res2Block(self.channels[1], self.channels[2], kernel_size=3, stride=1, padding=3, dilation=3, scale=8, se_bottleneck_dim=128)
203
+ self.layer4 = SE_Res2Block(self.channels[2], self.channels[3], kernel_size=3, stride=1, padding=4, dilation=4, scale=8, se_bottleneck_dim=128)
204
+
205
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
206
+ cat_channels = channels * 3
207
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
208
+ self.pooling = AttentiveStatsPool(self.channels[-1], attention_channels=128, global_context_att=global_context_att)
209
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
210
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
211
+
212
+
213
+ def get_feat_num(self):
214
+ self.feature_extract.eval()
215
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
216
+ with torch.no_grad():
217
+ features = self.feature_extract(wav)
218
+ select_feature = features[self.feature_selection]
219
+ if isinstance(select_feature, (list, tuple)):
220
+ return len(select_feature)
221
+ else:
222
+ return 1
223
+
224
+ def get_feat(self, x):
225
+ if self.update_extract:
226
+ x = self.feature_extract([sample for sample in x])
227
+ else:
228
+ with torch.no_grad():
229
+ if self.feat_type == 'fbank' or self.feat_type == 'mfcc':
230
+ x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
231
+ else:
232
+ x = self.feature_extract([sample for sample in x])
233
+
234
+ if self.feat_type == 'fbank':
235
+ x = x.log()
236
+
237
+ if self.feat_type != "fbank" and self.feat_type != "mfcc":
238
+ x = x[self.feature_selection]
239
+ if isinstance(x, (list, tuple)):
240
+ x = torch.stack(x, dim=0)
241
+ else:
242
+ x = x.unsqueeze(0)
243
+ norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
244
+ x = (norm_weights * x).sum(dim=0)
245
+ x = torch.transpose(x, 1, 2) + 1e-6
246
+
247
+ x = self.instance_norm(x)
248
+ return x
249
+
250
+ def forward(self, x):
251
+ x = self.get_feat(x)
252
+
253
+ out1 = self.layer1(x)
254
+ out2 = self.layer2(out1)
255
+ out3 = self.layer3(out2)
256
+ out4 = self.layer4(out3)
257
+
258
+ out = torch.cat([out2, out3, out4], dim=1)
259
+ out = F.relu(self.conv(out))
260
+ out = self.bn(self.pooling(out))
261
+ out = self.linear(out)
262
+
263
+ return out
264
+
265
+
266
+ def ECAPA_TDNN_SMALL(feat_dim, emb_dim=256, feat_type='wavlm_large', sr=16000, feature_selection="hidden_states", update_extract=False, config_path=None):
267
+ return ECAPA_TDNN(feat_dim=feat_dim, channels=512, emb_dim=emb_dim,
268
+ feat_type=feat_type, sr=sr, feature_selection=feature_selection, update_extract=update_extract, config_path=config_path)
model/model_modules.py ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+ from typing import Optional
12
+ import math
13
+
14
+ import torch
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+ import torchaudio
18
+
19
+ from einops import rearrange
20
+ from x_transformers.x_transformers import apply_rotary_pos_emb
21
+
22
+
23
+ # raw wav to mel spec
24
+
25
+ class MelSpec(nn.Module):
26
+ def __init__(
27
+ self,
28
+ filter_length = 1024,
29
+ hop_length = 256,
30
+ win_length = 1024,
31
+ n_mel_channels = 100,
32
+ target_sample_rate = 24_000,
33
+ normalize = False,
34
+ power = 1,
35
+ norm = None,
36
+ center = True,
37
+ ):
38
+ super().__init__()
39
+ self.n_mel_channels = n_mel_channels
40
+
41
+ self.mel_stft = torchaudio.transforms.MelSpectrogram(
42
+ sample_rate = target_sample_rate,
43
+ n_fft = filter_length,
44
+ win_length = win_length,
45
+ hop_length = hop_length,
46
+ n_mels = n_mel_channels,
47
+ power = power,
48
+ center = center,
49
+ normalized = normalize,
50
+ norm = norm,
51
+ )
52
+
53
+ self.register_buffer('dummy', torch.tensor(0), persistent = False)
54
+
55
+ def forward(self, inp):
56
+ if len(inp.shape) == 3:
57
+ inp = rearrange(inp, 'b 1 nw -> b nw')
58
+
59
+ assert len(inp.shape) == 2
60
+
61
+ if self.dummy.device != inp.device:
62
+ self.to(inp.device)
63
+
64
+ mel = self.mel_stft(inp)
65
+ mel = mel.clamp(min = 1e-5).log()
66
+ return mel
67
+
68
+
69
+ # sinusoidal position embedding
70
+
71
+ class SinusPositionEmbedding(nn.Module):
72
+ def __init__(self, dim):
73
+ super().__init__()
74
+ self.dim = dim
75
+
76
+ def forward(self, x, scale=1000):
77
+ device = x.device
78
+ half_dim = self.dim // 2
79
+ emb = math.log(10000) / (half_dim - 1)
80
+ emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
81
+ emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
82
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
83
+ return emb
84
+
85
+
86
+ # convolutional position embedding
87
+
88
+ class ConvPositionEmbedding(nn.Module):
89
+ def __init__(self, dim, kernel_size = 31, groups = 16):
90
+ super().__init__()
91
+ assert kernel_size % 2 != 0
92
+ self.conv1d = nn.Sequential(
93
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
94
+ nn.Mish(),
95
+ nn.Conv1d(dim, dim, kernel_size, groups = groups, padding = kernel_size // 2),
96
+ nn.Mish(),
97
+ )
98
+
99
+ def forward(self, x: float['b n d'], mask: bool['b n'] | None = None):
100
+ if mask is not None:
101
+ mask = mask[..., None]
102
+ x = x.masked_fill(~mask, 0.)
103
+
104
+ x = rearrange(x, 'b n d -> b d n')
105
+ x = self.conv1d(x)
106
+ out = rearrange(x, 'b d n -> b n d')
107
+
108
+ if mask is not None:
109
+ out = out.masked_fill(~mask, 0.)
110
+
111
+ return out
112
+
113
+
114
+ # rotary positional embedding related
115
+
116
+ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.):
117
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
118
+ # has some connection to NTK literature
119
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
120
+ # https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
121
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
122
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
123
+ t = torch.arange(end, device=freqs.device) # type: ignore
124
+ freqs = torch.outer(t, freqs).float() # type: ignore
125
+ freqs_cos = torch.cos(freqs) # real part
126
+ freqs_sin = torch.sin(freqs) # imaginary part
127
+ return torch.cat([freqs_cos, freqs_sin], dim=-1)
128
+
129
+ def get_pos_embed_indices(start, length, max_pos, scale=1.):
130
+ # length = length if isinstance(length, int) else length.max()
131
+ scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
132
+ pos = start.unsqueeze(1) + (
133
+ torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) *
134
+ scale.unsqueeze(1)).long()
135
+ # avoid extra long error.
136
+ pos = torch.where(pos < max_pos, pos, max_pos - 1)
137
+ return pos
138
+
139
+
140
+ # Global Response Normalization layer (Instance Normalization ?)
141
+
142
+ class GRN(nn.Module):
143
+ def __init__(self, dim):
144
+ super().__init__()
145
+ self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
146
+ self.beta = nn.Parameter(torch.zeros(1, 1, dim))
147
+
148
+ def forward(self, x):
149
+ Gx = torch.norm(x, p=2, dim=1, keepdim=True)
150
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
151
+ return self.gamma * (x * Nx) + self.beta + x
152
+
153
+
154
+ # ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
155
+ # ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
156
+
157
+ class ConvNeXtV2Block(nn.Module):
158
+ def __init__(
159
+ self,
160
+ dim: int,
161
+ intermediate_dim: int,
162
+ dilation: int = 1,
163
+ ):
164
+ super().__init__()
165
+ padding = (dilation * (7 - 1)) // 2
166
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation) # depthwise conv
167
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
168
+ self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
169
+ self.act = nn.GELU()
170
+ self.grn = GRN(intermediate_dim)
171
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
172
+
173
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
174
+ residual = x
175
+ x = x.transpose(1, 2) # b n d -> b d n
176
+ x = self.dwconv(x)
177
+ x = x.transpose(1, 2) # b d n -> b n d
178
+ x = self.norm(x)
179
+ x = self.pwconv1(x)
180
+ x = self.act(x)
181
+ x = self.grn(x)
182
+ x = self.pwconv2(x)
183
+ return residual + x
184
+
185
+
186
+ # AdaLayerNormZero
187
+ # return with modulated x for attn input, and params for later mlp modulation
188
+
189
+ class AdaLayerNormZero(nn.Module):
190
+ def __init__(self, dim):
191
+ super().__init__()
192
+
193
+ self.silu = nn.SiLU()
194
+ self.linear = nn.Linear(dim, dim * 6)
195
+
196
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
197
+
198
+ def forward(self, x, emb = None):
199
+ emb = self.linear(self.silu(emb))
200
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
201
+
202
+ x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
203
+ return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
204
+
205
+
206
+ # AdaLayerNormZero for final layer
207
+ # return only with modulated x for attn input, cuz no more mlp modulation
208
+
209
+ class AdaLayerNormZero_Final(nn.Module):
210
+ def __init__(self, dim):
211
+ super().__init__()
212
+
213
+ self.silu = nn.SiLU()
214
+ self.linear = nn.Linear(dim, dim * 2)
215
+
216
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
217
+
218
+ def forward(self, x, emb):
219
+ emb = self.linear(self.silu(emb))
220
+ scale, shift = torch.chunk(emb, 2, dim=1)
221
+
222
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
223
+ return x
224
+
225
+
226
+ # FeedForward
227
+
228
+ class FeedForward(nn.Module):
229
+ def __init__(self, dim, dim_out = None, mult = 4, dropout = 0., approximate: str = 'none'):
230
+ super().__init__()
231
+ inner_dim = int(dim * mult)
232
+ dim_out = dim_out if dim_out is not None else dim
233
+
234
+ activation = nn.GELU(approximate=approximate)
235
+ project_in = nn.Sequential(
236
+ nn.Linear(dim, inner_dim),
237
+ activation
238
+ )
239
+ self.ff = nn.Sequential(
240
+ project_in,
241
+ nn.Dropout(dropout),
242
+ nn.Linear(inner_dim, dim_out)
243
+ )
244
+
245
+ def forward(self, x):
246
+ return self.ff(x)
247
+
248
+
249
+ # Attention with possible joint part
250
+ # modified from diffusers/src/diffusers/models/attention_processor.py
251
+
252
+ class Attention(nn.Module):
253
+ def __init__(
254
+ self,
255
+ processor: JointAttnProcessor | AttnProcessor,
256
+ dim: int,
257
+ heads: int = 8,
258
+ dim_head: int = 64,
259
+ dropout: float = 0.0,
260
+ context_dim: Optional[int] = None, # if not None -> joint attention
261
+ context_pre_only = None,
262
+ ):
263
+ super().__init__()
264
+
265
+ if not hasattr(F, "scaled_dot_product_attention"):
266
+ raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
267
+
268
+ self.processor = processor
269
+
270
+ self.dim = dim
271
+ self.heads = heads
272
+ self.inner_dim = dim_head * heads
273
+ self.dropout = dropout
274
+
275
+ self.context_dim = context_dim
276
+ self.context_pre_only = context_pre_only
277
+
278
+ self.to_q = nn.Linear(dim, self.inner_dim)
279
+ self.to_k = nn.Linear(dim, self.inner_dim)
280
+ self.to_v = nn.Linear(dim, self.inner_dim)
281
+
282
+ if self.context_dim is not None:
283
+ self.to_k_c = nn.Linear(context_dim, self.inner_dim)
284
+ self.to_v_c = nn.Linear(context_dim, self.inner_dim)
285
+ if self.context_pre_only is not None:
286
+ self.to_q_c = nn.Linear(context_dim, self.inner_dim)
287
+
288
+ self.to_out = nn.ModuleList([])
289
+ self.to_out.append(nn.Linear(self.inner_dim, dim))
290
+ self.to_out.append(nn.Dropout(dropout))
291
+
292
+ if self.context_pre_only is not None and not self.context_pre_only:
293
+ self.to_out_c = nn.Linear(self.inner_dim, dim)
294
+
295
+ def forward(
296
+ self,
297
+ x: float['b n d'], # noised input x
298
+ c: float['b n d'] = None, # context c
299
+ mask: bool['b n'] | None = None,
300
+ rope = None, # rotary position embedding for x
301
+ c_rope = None, # rotary position embedding for c
302
+ ) -> torch.Tensor:
303
+ if c is not None:
304
+ return self.processor(self, x, c = c, mask = mask, rope = rope, c_rope = c_rope)
305
+ else:
306
+ return self.processor(self, x, mask = mask, rope = rope)
307
+
308
+
309
+ # Attention processor
310
+
311
+ class AttnProcessor:
312
+ def __init__(self):
313
+ pass
314
+
315
+ def __call__(
316
+ self,
317
+ attn: Attention,
318
+ x: float['b n d'], # noised input x
319
+ mask: bool['b n'] | None = None,
320
+ rope = None, # rotary position embedding
321
+ ) -> torch.FloatTensor:
322
+
323
+ batch_size = x.shape[0]
324
+
325
+ # `sample` projections.
326
+ query = attn.to_q(x)
327
+ key = attn.to_k(x)
328
+ value = attn.to_v(x)
329
+
330
+ # apply rotary position embedding
331
+ if rope is not None:
332
+ freqs, xpos_scale = rope
333
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
334
+
335
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
336
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
337
+
338
+ # attention
339
+ inner_dim = key.shape[-1]
340
+ head_dim = inner_dim // attn.heads
341
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
342
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
343
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
344
+
345
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
346
+ if mask is not None:
347
+ attn_mask = mask
348
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
349
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
350
+ else:
351
+ attn_mask = None
352
+
353
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
354
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
355
+ x = x.to(query.dtype)
356
+
357
+ # linear proj
358
+ x = attn.to_out[0](x)
359
+ # dropout
360
+ x = attn.to_out[1](x)
361
+
362
+ if mask is not None:
363
+ mask = rearrange(mask, 'b n -> b n 1')
364
+ x = x.masked_fill(~mask, 0.)
365
+
366
+ return x
367
+
368
+
369
+ # Joint Attention processor for MM-DiT
370
+ # modified from diffusers/src/diffusers/models/attention_processor.py
371
+
372
+ class JointAttnProcessor:
373
+ def __init__(self):
374
+ pass
375
+
376
+ def __call__(
377
+ self,
378
+ attn: Attention,
379
+ x: float['b n d'], # noised input x
380
+ c: float['b nt d'] = None, # context c, here text
381
+ mask: bool['b n'] | None = None,
382
+ rope = None, # rotary position embedding for x
383
+ c_rope = None, # rotary position embedding for c
384
+ ) -> torch.FloatTensor:
385
+ residual = x
386
+
387
+ batch_size = c.shape[0]
388
+
389
+ # `sample` projections.
390
+ query = attn.to_q(x)
391
+ key = attn.to_k(x)
392
+ value = attn.to_v(x)
393
+
394
+ # `context` projections.
395
+ c_query = attn.to_q_c(c)
396
+ c_key = attn.to_k_c(c)
397
+ c_value = attn.to_v_c(c)
398
+
399
+ # apply rope for context and noised input independently
400
+ if rope is not None:
401
+ freqs, xpos_scale = rope
402
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
403
+ query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
404
+ key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
405
+ if c_rope is not None:
406
+ freqs, xpos_scale = c_rope
407
+ q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if xpos_scale is not None else (1., 1.)
408
+ c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
409
+ c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
410
+
411
+ # attention
412
+ query = torch.cat([query, c_query], dim=1)
413
+ key = torch.cat([key, c_key], dim=1)
414
+ value = torch.cat([value, c_value], dim=1)
415
+
416
+ inner_dim = key.shape[-1]
417
+ head_dim = inner_dim // attn.heads
418
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
419
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
420
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
421
+
422
+ # mask. e.g. inference got a batch with different target durations, mask out the padding
423
+ if mask is not None:
424
+ attn_mask = F.pad(mask, (0, c.shape[1]), value = True) # no mask for c (text)
425
+ attn_mask = rearrange(attn_mask, 'b n -> b 1 1 n')
426
+ attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
427
+ else:
428
+ attn_mask = None
429
+
430
+ x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
431
+ x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
432
+ x = x.to(query.dtype)
433
+
434
+ # Split the attention outputs.
435
+ x, c = (
436
+ x[:, :residual.shape[1]],
437
+ x[:, residual.shape[1]:],
438
+ )
439
+
440
+ # linear proj
441
+ x = attn.to_out[0](x)
442
+ # dropout
443
+ x = attn.to_out[1](x)
444
+ if not attn.context_pre_only:
445
+ c = attn.to_out_c(c)
446
+
447
+ if mask is not None:
448
+ mask = rearrange(mask, 'b n -> b n 1')
449
+ x = x.masked_fill(~mask, 0.)
450
+ # c = c.masked_fill(~mask, 0.) # no mask for c (text)
451
+
452
+ return x, c
453
+
454
+
455
+ # DiT Block
456
+
457
+ class DiTBlock(nn.Module):
458
+
459
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1):
460
+ super().__init__()
461
+
462
+ self.attn_norm = AdaLayerNormZero(dim)
463
+ self.attn = Attention(
464
+ processor = AttnProcessor(),
465
+ dim = dim,
466
+ heads = heads,
467
+ dim_head = dim_head,
468
+ dropout = dropout,
469
+ )
470
+
471
+ self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
472
+ self.ff = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
473
+
474
+ def forward(self, x, t, mask = None, rope = None): # x: noised input, t: time embedding
475
+ # pre-norm & modulation for attention input
476
+ norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
477
+
478
+ # attention
479
+ attn_output = self.attn(x=norm, mask=mask, rope=rope)
480
+
481
+ # process attention output for input x
482
+ x = x + gate_msa.unsqueeze(1) * attn_output
483
+
484
+ norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
485
+ ff_output = self.ff(norm)
486
+ x = x + gate_mlp.unsqueeze(1) * ff_output
487
+
488
+ return x
489
+
490
+
491
+ # MMDiT Block https://arxiv.org/abs/2403.03206
492
+
493
+ class MMDiTBlock(nn.Module):
494
+ r"""
495
+ modified from diffusers/src/diffusers/models/attention.py
496
+
497
+ notes.
498
+ _c: context related. text, cond, etc. (left part in sd3 fig2.b)
499
+ _x: noised input related. (right part)
500
+ context_pre_only: last layer only do prenorm + modulation cuz no more ffn
501
+ """
502
+
503
+ def __init__(self, dim, heads, dim_head, ff_mult = 4, dropout = 0.1, context_pre_only = False):
504
+ super().__init__()
505
+
506
+ self.context_pre_only = context_pre_only
507
+
508
+ self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
509
+ self.attn_norm_x = AdaLayerNormZero(dim)
510
+ self.attn = Attention(
511
+ processor = JointAttnProcessor(),
512
+ dim = dim,
513
+ heads = heads,
514
+ dim_head = dim_head,
515
+ dropout = dropout,
516
+ context_dim = dim,
517
+ context_pre_only = context_pre_only,
518
+ )
519
+
520
+ if not context_pre_only:
521
+ self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
522
+ self.ff_c = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
523
+ else:
524
+ self.ff_norm_c = None
525
+ self.ff_c = None
526
+ self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
527
+ self.ff_x = FeedForward(dim = dim, mult = ff_mult, dropout = dropout, approximate = "tanh")
528
+
529
+ def forward(self, x, c, t, mask = None, rope = None, c_rope = None): # x: noised input, c: context, t: time embedding
530
+ # pre-norm & modulation for attention input
531
+ if self.context_pre_only:
532
+ norm_c = self.attn_norm_c(c, t)
533
+ else:
534
+ norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
535
+ norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
536
+
537
+ # attention
538
+ x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
539
+
540
+ # process attention output for context c
541
+ if self.context_pre_only:
542
+ c = None
543
+ else: # if not last layer
544
+ c = c + c_gate_msa.unsqueeze(1) * c_attn_output
545
+
546
+ norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
547
+ c_ff_output = self.ff_c(norm_c)
548
+ c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
549
+
550
+ # process attention output for input x
551
+ x = x + x_gate_msa.unsqueeze(1) * x_attn_output
552
+
553
+ norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
554
+ x_ff_output = self.ff_x(norm_x)
555
+ x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
556
+
557
+ return c, x
558
+
559
+
560
+ # time step conditioning embedding
561
+
562
+ class TimestepEmbedding(nn.Module):
563
+ def __init__(self, dim, freq_embed_dim=256):
564
+ super().__init__()
565
+ self.time_embed = SinusPositionEmbedding(freq_embed_dim)
566
+ self.time_mlp = nn.Sequential(
567
+ nn.Linear(freq_embed_dim, dim),
568
+ nn.SiLU(),
569
+ nn.Linear(dim, dim)
570
+ )
571
+
572
+ def forward(self, timestep: float['b']):
573
+ time_hidden = self.time_embed(timestep)
574
+ time = self.time_mlp(time_hidden) # b d
575
+ return time
model/model_trainer.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import gc
5
+ from tqdm import tqdm
6
+ import wandb
7
+
8
+ import torch
9
+ from torch.optim import AdamW
10
+ from torch.utils.data import DataLoader, Dataset, SequentialSampler
11
+ from torch.optim.lr_scheduler import LinearLR, SequentialLR
12
+
13
+ from einops import rearrange
14
+
15
+ from accelerate import Accelerator
16
+ from accelerate.utils import DistributedDataParallelKwargs
17
+
18
+ from ema_pytorch import EMA
19
+
20
+ from model import CFM
21
+ from model.utils import exists, default
22
+ from model.dataset import DynamicBatchSampler, collate_fn
23
+
24
+
25
+ # trainer
26
+
27
+ class Trainer:
28
+ def __init__(
29
+ self,
30
+ model: CFM,
31
+ epochs,
32
+ learning_rate,
33
+ num_warmup_updates = 20000,
34
+ save_per_updates = 1000,
35
+ checkpoint_path = None,
36
+ batch_size = 32,
37
+ batch_size_type: str = "sample",
38
+ max_samples = 32,
39
+ grad_accumulation_steps = 1,
40
+ max_grad_norm = 1.0,
41
+ noise_scheduler: str | None = None,
42
+ duration_predictor: torch.nn.Module | None = None,
43
+ wandb_project = "test_e2-tts",
44
+ wandb_run_name = "test_run",
45
+ wandb_resume_id: str = None,
46
+ last_per_steps = None,
47
+ accelerate_kwargs: dict = dict(),
48
+ ema_kwargs: dict = dict()
49
+ ):
50
+
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
52
+
53
+ self.accelerator = Accelerator(
54
+ log_with = "wandb",
55
+ kwargs_handlers = [ddp_kwargs],
56
+ gradient_accumulation_steps = grad_accumulation_steps,
57
+ **accelerate_kwargs
58
+ )
59
+
60
+ if exists(wandb_resume_id):
61
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name, 'id': wandb_resume_id}}
62
+ else:
63
+ init_kwargs={"wandb": {"resume": "allow", "name": wandb_run_name}}
64
+ self.accelerator.init_trackers(
65
+ project_name = wandb_project,
66
+ init_kwargs=init_kwargs,
67
+ config={"epochs": epochs,
68
+ "learning_rate": learning_rate,
69
+ "num_warmup_updates": num_warmup_updates,
70
+ "batch_size": batch_size,
71
+ "batch_size_type": batch_size_type,
72
+ "max_samples": max_samples,
73
+ "grad_accumulation_steps": grad_accumulation_steps,
74
+ "max_grad_norm": max_grad_norm,
75
+ "gpus": self.accelerator.num_processes,
76
+ "noise_scheduler": noise_scheduler}
77
+ )
78
+
79
+ self.model = model
80
+
81
+ if self.is_main:
82
+ self.ema_model = EMA(
83
+ model,
84
+ include_online_model = False,
85
+ **ema_kwargs
86
+ )
87
+
88
+ self.ema_model.to(self.accelerator.device)
89
+
90
+ self.epochs = epochs
91
+ self.num_warmup_updates = num_warmup_updates
92
+ self.save_per_updates = save_per_updates
93
+ self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
94
+ self.checkpoint_path = default(checkpoint_path, 'ckpts/test_e2-tts')
95
+
96
+ self.batch_size = batch_size
97
+ self.batch_size_type = batch_size_type
98
+ self.max_samples = max_samples
99
+ self.grad_accumulation_steps = grad_accumulation_steps
100
+ self.max_grad_norm = max_grad_norm
101
+
102
+ self.noise_scheduler = noise_scheduler
103
+
104
+ self.duration_predictor = duration_predictor
105
+
106
+ self.optimizer = AdamW(model.parameters(), lr=learning_rate)
107
+ self.model, self.optimizer = self.accelerator.prepare(
108
+ self.model, self.optimizer
109
+ )
110
+
111
+ @property
112
+ def is_main(self):
113
+ return self.accelerator.is_main_process
114
+
115
+ def save_checkpoint(self, step, last=False):
116
+ self.accelerator.wait_for_everyone()
117
+ if self.is_main:
118
+ checkpoint = dict(
119
+ model_state_dict = self.accelerator.unwrap_model(self.model).state_dict(),
120
+ optimizer_state_dict = self.accelerator.unwrap_model(self.optimizer).state_dict(),
121
+ ema_model_state_dict = self.ema_model.state_dict(),
122
+ scheduler_state_dict = self.scheduler.state_dict(),
123
+ step = step
124
+ )
125
+ if not os.path.exists(self.checkpoint_path):
126
+ os.makedirs(self.checkpoint_path)
127
+ if last == True:
128
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
129
+ print(f"Saved last checkpoint at step {step}")
130
+ else:
131
+ self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")
132
+
133
+ def load_checkpoint(self):
134
+ if not exists(self.checkpoint_path) or not os.path.exists(self.checkpoint_path) or not os.listdir(self.checkpoint_path):
135
+ return 0
136
+
137
+ self.accelerator.wait_for_everyone()
138
+ if "model_last.pt" in os.listdir(self.checkpoint_path):
139
+ latest_checkpoint = "model_last.pt"
140
+ else:
141
+ latest_checkpoint = sorted([f for f in os.listdir(self.checkpoint_path) if f.endswith('.pt')], key=lambda x: int(''.join(filter(str.isdigit, x))))[-1]
142
+ # checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location=self.accelerator.device) # rather use accelerator.load_state ಥ_ಥ
143
+ checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", weights_only=True, map_location="cpu")
144
+
145
+ if self.is_main:
146
+ self.ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
147
+
148
+ if 'step' in checkpoint:
149
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
150
+ self.accelerator.unwrap_model(self.optimizer).load_state_dict(checkpoint['optimizer_state_dict'])
151
+ if self.scheduler:
152
+ self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
153
+ step = checkpoint['step']
154
+ else:
155
+ checkpoint['model_state_dict'] = {k.replace("ema_model.", ""): v for k, v in checkpoint['ema_model_state_dict'].items() if k not in ["initted", "step"]}
156
+ self.accelerator.unwrap_model(self.model).load_state_dict(checkpoint['model_state_dict'])
157
+ step = 0
158
+
159
+ del checkpoint; gc.collect()
160
+ return step
161
+
162
+ def train(self, train_dataset: Dataset, num_workers=16, resumable_with_seed: int = None):
163
+
164
+ if exists(resumable_with_seed):
165
+ generator = torch.Generator()
166
+ generator.manual_seed(resumable_with_seed)
167
+ else:
168
+ generator = None
169
+
170
+ if self.batch_size_type == "sample":
171
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
172
+ batch_size=self.batch_size, shuffle=True, generator=generator)
173
+ elif self.batch_size_type == "frame":
174
+ self.accelerator.even_batches = False
175
+ sampler = SequentialSampler(train_dataset)
176
+ batch_sampler = DynamicBatchSampler(sampler, self.batch_size, max_samples=self.max_samples, random_seed=resumable_with_seed, drop_last=False)
177
+ train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, num_workers=num_workers, pin_memory=True, persistent_workers=True,
178
+ batch_sampler=batch_sampler)
179
+ else:
180
+ raise ValueError(f"batch_size_type must be either 'sample' or 'frame', but received {self.batch_size_type}")
181
+
182
+ # accelerator.prepare() dispatches batches to devices;
183
+ # which means the length of dataloader calculated before, should consider the number of devices
184
+ warmup_steps = self.num_warmup_updates * self.accelerator.num_processes # consider a fixed warmup steps while using accelerate multi-gpu ddp
185
+ # otherwise by default with split_batches=False, warmup steps change with num_processes
186
+ total_steps = len(train_dataloader) * self.epochs / self.grad_accumulation_steps
187
+ decay_steps = total_steps - warmup_steps
188
+ warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
189
+ decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
190
+ self.scheduler = SequentialLR(self.optimizer,
191
+ schedulers=[warmup_scheduler, decay_scheduler],
192
+ milestones=[warmup_steps])
193
+ train_dataloader, self.scheduler = self.accelerator.prepare(train_dataloader, self.scheduler) # actual steps = 1 gpu steps / gpus
194
+ start_step = self.load_checkpoint()
195
+ global_step = start_step
196
+
197
+ if exists(resumable_with_seed):
198
+ orig_epoch_step = len(train_dataloader)
199
+ skipped_epoch = int(start_step // orig_epoch_step)
200
+ skipped_batch = start_step % orig_epoch_step
201
+ skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
202
+ else:
203
+ skipped_epoch = 0
204
+
205
+ for epoch in range(skipped_epoch, self.epochs):
206
+ self.model.train()
207
+ if exists(resumable_with_seed) and epoch == skipped_epoch:
208
+ progress_bar = tqdm(skipped_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process,
209
+ initial=skipped_batch, total=orig_epoch_step)
210
+ else:
211
+ progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{self.epochs}", unit="step", disable=not self.accelerator.is_local_main_process)
212
+
213
+ for batch in progress_bar:
214
+ with self.accelerator.accumulate(self.model):
215
+ text_inputs = batch['text']
216
+ mel_spec = rearrange(batch['mel'], 'b d n -> b n d')
217
+ mel_lengths = batch["mel_lengths"]
218
+
219
+ # TODO. add duration predictor training
220
+ if self.duration_predictor is not None and self.accelerator.is_local_main_process:
221
+ dur_loss = self.duration_predictor(mel_spec, lens=batch.get('durations'))
222
+ self.accelerator.log({"duration loss": dur_loss.item()}, step=global_step)
223
+
224
+ loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler)
225
+ self.accelerator.backward(loss)
226
+
227
+ if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
228
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
229
+
230
+ self.optimizer.step()
231
+ self.scheduler.step()
232
+ self.optimizer.zero_grad()
233
+
234
+ if self.is_main:
235
+ self.ema_model.update()
236
+
237
+ global_step += 1
238
+
239
+ if self.accelerator.is_local_main_process:
240
+ self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)
241
+
242
+ progress_bar.set_postfix(step=str(global_step), loss=loss.item())
243
+
244
+ if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
245
+ self.save_checkpoint(global_step)
246
+
247
+ if global_step % self.last_per_steps == 0:
248
+ self.save_checkpoint(global_step, last=True)
249
+
250
+ self.accelerator.end_training()
model/model_utils.py ADDED
@@ -0,0 +1,580 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import re
5
+ import math
6
+ import random
7
+ import string
8
+ from tqdm import tqdm
9
+ from collections import defaultdict
10
+
11
+ import matplotlib
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pylab as plt
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch.nn.utils.rnn import pad_sequence
18
+ import torchaudio
19
+
20
+ import einx
21
+ from einops import rearrange, reduce
22
+
23
+ import jieba
24
+ from pypinyin import lazy_pinyin, Style
25
+
26
+ from model.ecapa_tdnn import ECAPA_TDNN_SMALL
27
+ from model.modules import MelSpec
28
+
29
+
30
+ # seed everything
31
+
32
+ def seed_everything(seed = 0):
33
+ random.seed(seed)
34
+ os.environ['PYTHONHASHSEED'] = str(seed)
35
+ torch.manual_seed(seed)
36
+ torch.cuda.manual_seed(seed)
37
+ torch.cuda.manual_seed_all(seed)
38
+ torch.backends.cudnn.deterministic = True
39
+ torch.backends.cudnn.benchmark = False
40
+
41
+ # helpers
42
+
43
+ def exists(v):
44
+ return v is not None
45
+
46
+ def default(v, d):
47
+ return v if exists(v) else d
48
+
49
+ # tensor helpers
50
+
51
+ def lens_to_mask(
52
+ t: int['b'],
53
+ length: int | None = None
54
+ ) -> bool['b n']:
55
+
56
+ if not exists(length):
57
+ length = t.amax()
58
+
59
+ seq = torch.arange(length, device = t.device)
60
+ return einx.less('n, b -> b n', seq, t)
61
+
62
+ def mask_from_start_end_indices(
63
+ seq_len: int['b'],
64
+ start: int['b'],
65
+ end: int['b']
66
+ ):
67
+ max_seq_len = seq_len.max().item()
68
+ seq = torch.arange(max_seq_len, device = start.device).long()
69
+ return einx.greater_equal('n, b -> b n', seq, start) & einx.less('n, b -> b n', seq, end)
70
+
71
+ def mask_from_frac_lengths(
72
+ seq_len: int['b'],
73
+ frac_lengths: float['b']
74
+ ):
75
+ lengths = (frac_lengths * seq_len).long()
76
+ max_start = seq_len - lengths
77
+
78
+ rand = torch.rand_like(frac_lengths)
79
+ start = (max_start * rand).long().clamp(min = 0)
80
+ end = start + lengths
81
+
82
+ return mask_from_start_end_indices(seq_len, start, end)
83
+
84
+ def maybe_masked_mean(
85
+ t: float['b n d'],
86
+ mask: bool['b n'] = None
87
+ ) -> float['b d']:
88
+
89
+ if not exists(mask):
90
+ return t.mean(dim = 1)
91
+
92
+ t = einx.where('b n, b n d, -> b n d', mask, t, 0.)
93
+ num = reduce(t, 'b n d -> b d', 'sum')
94
+ den = reduce(mask.float(), 'b n -> b', 'sum')
95
+
96
+ return einx.divide('b d, b -> b d', num, den.clamp(min = 1.))
97
+
98
+
99
+ # simple utf-8 tokenizer, since paper went character based
100
+ def list_str_to_tensor(
101
+ text: list[str],
102
+ padding_value = -1
103
+ ) -> int['b nt']:
104
+ list_tensors = [torch.tensor([*bytes(t, 'UTF-8')]) for t in text] # ByT5 style
105
+ text = pad_sequence(list_tensors, padding_value = padding_value, batch_first = True)
106
+ return text
107
+
108
+ # char tokenizer, based on custom dataset's extracted .txt file
109
+ def list_str_to_idx(
110
+ text: list[str] | list[list[str]],
111
+ vocab_char_map: dict[str, int], # {char: idx}
112
+ padding_value = -1
113
+ ) -> int['b nt']:
114
+ list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text] # pinyin or char style
115
+ text = pad_sequence(list_idx_tensors, padding_value = padding_value, batch_first = True)
116
+ return text
117
+
118
+
119
+ # Get tokenizer
120
+
121
+ def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
122
+ '''
123
+ tokenizer - "pinyin" do g2p for only chinese characters, need .txt vocab_file
124
+ - "char" for char-wise tokenizer, need .txt vocab_file
125
+ - "byte" for utf-8 tokenizer
126
+ - "custom" if you're directly passing in a path to the vocab.txt you want to use
127
+ vocab_size - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
128
+ - if use "char", derived from unfiltered character & symbol counts of custom dataset
129
+ - if use "byte", set to 256 (unicode byte range)
130
+ '''
131
+ if tokenizer in ["pinyin", "char"]:
132
+ with open (f"data/{dataset_name}_{tokenizer}/vocab.txt", "r", encoding="utf-8") as f:
133
+ vocab_char_map = {}
134
+ for i, char in enumerate(f):
135
+ vocab_char_map[char[:-1]] = i
136
+ vocab_size = len(vocab_char_map)
137
+ assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"
138
+
139
+ elif tokenizer == "byte":
140
+ vocab_char_map = None
141
+ vocab_size = 256
142
+ elif tokenizer == "custom":
143
+ with open (dataset_name, "r", encoding="utf-8") as f:
144
+ vocab_char_map = {}
145
+ for i, char in enumerate(f):
146
+ vocab_char_map[char[:-1]] = i
147
+ vocab_size = len(vocab_char_map)
148
+
149
+ return vocab_char_map, vocab_size
150
+
151
+
152
+ # convert char to pinyin
153
+
154
+ def convert_char_to_pinyin(text_list, polyphone = True):
155
+ final_text_list = []
156
+ god_knows_why_en_testset_contains_zh_quote = str.maketrans({'“': '"', '”': '"', '‘': "'", '’': "'"}) # in case librispeech (orig no-pc) test-clean
157
+ custom_trans = str.maketrans({';': ','}) # add custom trans here, to address oov
158
+ for text in text_list:
159
+ char_list = []
160
+ text = text.translate(god_knows_why_en_testset_contains_zh_quote)
161
+ text = text.translate(custom_trans)
162
+ for seg in jieba.cut(text):
163
+ seg_byte_len = len(bytes(seg, 'UTF-8'))
164
+ if seg_byte_len == len(seg): # if pure alphabets and symbols
165
+ if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
166
+ char_list.append(" ")
167
+ char_list.extend(seg)
168
+ elif polyphone and seg_byte_len == 3 * len(seg): # if pure chinese characters
169
+ seg = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
170
+ for c in seg:
171
+ if c not in "。,、;:?!《》【】—…":
172
+ char_list.append(" ")
173
+ char_list.append(c)
174
+ else: # if mixed chinese characters, alphabets and symbols
175
+ for c in seg:
176
+ if ord(c) < 256:
177
+ char_list.extend(c)
178
+ else:
179
+ if c not in "。,、;:?!《》【】—…":
180
+ char_list.append(" ")
181
+ char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
182
+ else: # if is zh punc
183
+ char_list.append(c)
184
+ final_text_list.append(char_list)
185
+
186
+ return final_text_list
187
+
188
+
189
+ # save spectrogram
190
+ def save_spectrogram(spectrogram, path):
191
+ plt.figure(figsize=(12, 4))
192
+ plt.imshow(spectrogram, origin='lower', aspect='auto')
193
+ plt.colorbar()
194
+ plt.savefig(path)
195
+ plt.close()
196
+
197
+
198
+ # seedtts testset metainfo: utt, prompt_text, prompt_wav, gt_text, gt_wav
199
+ def get_seedtts_testset_metainfo(metalst):
200
+ f = open(metalst); lines = f.readlines(); f.close()
201
+ metainfo = []
202
+ for line in lines:
203
+ if len(line.strip().split('|')) == 5:
204
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
205
+ elif len(line.strip().split('|')) == 4:
206
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
207
+ gt_wav = os.path.join(os.path.dirname(metalst), "wavs", utt + ".wav")
208
+ if not os.path.isabs(prompt_wav):
209
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
210
+ metainfo.append((utt, prompt_text, prompt_wav, gt_text, gt_wav))
211
+ return metainfo
212
+
213
+
214
+ # librispeech test-clean metainfo: gen_utt, ref_txt, ref_wav, gen_txt, gen_wav
215
+ def get_librispeech_test_clean_metainfo(metalst, librispeech_test_clean_path):
216
+ f = open(metalst); lines = f.readlines(); f.close()
217
+ metainfo = []
218
+ for line in lines:
219
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
220
+
221
+ # ref_txt = ref_txt[0] + ref_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
222
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
223
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
224
+
225
+ # gen_txt = gen_txt[0] + gen_txt[1:].lower() + '.' # if use librispeech test-clean (no-pc)
226
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
227
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
228
+
229
+ metainfo.append((gen_utt, ref_txt, ref_wav, " " + gen_txt, gen_wav))
230
+
231
+ return metainfo
232
+
233
+
234
+ # padded to max length mel batch
235
+ def padded_mel_batch(ref_mels):
236
+ max_mel_length = torch.LongTensor([mel.shape[-1] for mel in ref_mels]).amax()
237
+ padded_ref_mels = []
238
+ for mel in ref_mels:
239
+ padded_ref_mel = F.pad(mel, (0, max_mel_length - mel.shape[-1]), value = 0)
240
+ padded_ref_mels.append(padded_ref_mel)
241
+ padded_ref_mels = torch.stack(padded_ref_mels)
242
+ padded_ref_mels = rearrange(padded_ref_mels, 'b d n -> b n d')
243
+ return padded_ref_mels
244
+
245
+
246
+ # get prompts from metainfo containing: utt, prompt_text, prompt_wav, gt_text, gt_wav
247
+
248
+ def get_inference_prompt(
249
+ metainfo,
250
+ speed = 1., tokenizer = "pinyin", polyphone = True,
251
+ target_sample_rate = 24000, n_mel_channels = 100, hop_length = 256, target_rms = 0.1,
252
+ use_truth_duration = False,
253
+ infer_batch_size = 1, num_buckets = 200, min_secs = 3, max_secs = 40,
254
+ ):
255
+ prompts_all = []
256
+
257
+ min_tokens = min_secs * target_sample_rate // hop_length
258
+ max_tokens = max_secs * target_sample_rate // hop_length
259
+
260
+ batch_accum = [0] * num_buckets
261
+ utts, ref_rms_list, ref_mels, ref_mel_lens, total_mel_lens, final_text_list = \
262
+ ([[] for _ in range(num_buckets)] for _ in range(6))
263
+
264
+ mel_spectrogram = MelSpec(target_sample_rate=target_sample_rate, n_mel_channels=n_mel_channels, hop_length=hop_length)
265
+
266
+ for utt, prompt_text, prompt_wav, gt_text, gt_wav in tqdm(metainfo, desc="Processing prompts..."):
267
+
268
+ # Audio
269
+ ref_audio, ref_sr = torchaudio.load(prompt_wav)
270
+ ref_rms = torch.sqrt(torch.mean(torch.square(ref_audio)))
271
+ if ref_rms < target_rms:
272
+ ref_audio = ref_audio * target_rms / ref_rms
273
+ assert ref_audio.shape[-1] > 5000, f"Empty prompt wav: {prompt_wav}, or torchaudio backend issue."
274
+ if ref_sr != target_sample_rate:
275
+ resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
276
+ ref_audio = resampler(ref_audio)
277
+
278
+ # Text
279
+ if len(prompt_text[-1].encode('utf-8')) == 1:
280
+ prompt_text = prompt_text + " "
281
+ text = [prompt_text + gt_text]
282
+ if tokenizer == "pinyin":
283
+ text_list = convert_char_to_pinyin(text, polyphone = polyphone)
284
+ else:
285
+ text_list = text
286
+
287
+ # Duration, mel frame length
288
+ ref_mel_len = ref_audio.shape[-1] // hop_length
289
+ if use_truth_duration:
290
+ gt_audio, gt_sr = torchaudio.load(gt_wav)
291
+ if gt_sr != target_sample_rate:
292
+ resampler = torchaudio.transforms.Resample(gt_sr, target_sample_rate)
293
+ gt_audio = resampler(gt_audio)
294
+ total_mel_len = ref_mel_len + int(gt_audio.shape[-1] / hop_length / speed)
295
+
296
+ # # test vocoder resynthesis
297
+ # ref_audio = gt_audio
298
+ else:
299
+ zh_pause_punc = r"。,、;:?!"
300
+ ref_text_len = len(prompt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, prompt_text))
301
+ gen_text_len = len(gt_text.encode('utf-8')) + 3 * len(re.findall(zh_pause_punc, gt_text))
302
+ total_mel_len = ref_mel_len + int(ref_mel_len / ref_text_len * gen_text_len / speed)
303
+
304
+ # to mel spectrogram
305
+ ref_mel = mel_spectrogram(ref_audio)
306
+ ref_mel = rearrange(ref_mel, '1 d n -> d n')
307
+
308
+ # deal with batch
309
+ assert infer_batch_size > 0, "infer_batch_size should be greater than 0."
310
+ assert min_tokens <= total_mel_len <= max_tokens, \
311
+ f"Audio {utt} has duration {total_mel_len*hop_length//target_sample_rate}s out of range [{min_secs}, {max_secs}]."
312
+ bucket_i = math.floor((total_mel_len - min_tokens) / (max_tokens - min_tokens + 1) * num_buckets)
313
+
314
+ utts[bucket_i].append(utt)
315
+ ref_rms_list[bucket_i].append(ref_rms)
316
+ ref_mels[bucket_i].append(ref_mel)
317
+ ref_mel_lens[bucket_i].append(ref_mel_len)
318
+ total_mel_lens[bucket_i].append(total_mel_len)
319
+ final_text_list[bucket_i].extend(text_list)
320
+
321
+ batch_accum[bucket_i] += total_mel_len
322
+
323
+ if batch_accum[bucket_i] >= infer_batch_size:
324
+ # print(f"\n{len(ref_mels[bucket_i][0][0])}\n{ref_mel_lens[bucket_i]}\n{total_mel_lens[bucket_i]}")
325
+ prompts_all.append((
326
+ utts[bucket_i],
327
+ ref_rms_list[bucket_i],
328
+ padded_mel_batch(ref_mels[bucket_i]),
329
+ ref_mel_lens[bucket_i],
330
+ total_mel_lens[bucket_i],
331
+ final_text_list[bucket_i]
332
+ ))
333
+ batch_accum[bucket_i] = 0
334
+ utts[bucket_i], ref_rms_list[bucket_i], ref_mels[bucket_i], ref_mel_lens[bucket_i], total_mel_lens[bucket_i], final_text_list[bucket_i] = [], [], [], [], [], []
335
+
336
+ # add residual
337
+ for bucket_i, bucket_frames in enumerate(batch_accum):
338
+ if bucket_frames > 0:
339
+ prompts_all.append((
340
+ utts[bucket_i],
341
+ ref_rms_list[bucket_i],
342
+ padded_mel_batch(ref_mels[bucket_i]),
343
+ ref_mel_lens[bucket_i],
344
+ total_mel_lens[bucket_i],
345
+ final_text_list[bucket_i]
346
+ ))
347
+ # not only leave easy work for last workers
348
+ random.seed(666)
349
+ random.shuffle(prompts_all)
350
+
351
+ return prompts_all
352
+
353
+
354
+ # get wav_res_ref_text of seed-tts test metalst
355
+ # https://github.com/BytedanceSpeech/seed-tts-eval
356
+
357
+ def get_seed_tts_test(metalst, gen_wav_dir, gpus):
358
+ f = open(metalst)
359
+ lines = f.readlines()
360
+ f.close()
361
+
362
+ test_set_ = []
363
+ for line in tqdm(lines):
364
+ if len(line.strip().split('|')) == 5:
365
+ utt, prompt_text, prompt_wav, gt_text, gt_wav = line.strip().split('|')
366
+ elif len(line.strip().split('|')) == 4:
367
+ utt, prompt_text, prompt_wav, gt_text = line.strip().split('|')
368
+
369
+ if not os.path.exists(os.path.join(gen_wav_dir, utt + '.wav')):
370
+ continue
371
+ gen_wav = os.path.join(gen_wav_dir, utt + '.wav')
372
+ if not os.path.isabs(prompt_wav):
373
+ prompt_wav = os.path.join(os.path.dirname(metalst), prompt_wav)
374
+
375
+ test_set_.append((gen_wav, prompt_wav, gt_text))
376
+
377
+ num_jobs = len(gpus)
378
+ if num_jobs == 1:
379
+ return [(gpus[0], test_set_)]
380
+
381
+ wav_per_job = len(test_set_) // num_jobs + 1
382
+ test_set = []
383
+ for i in range(num_jobs):
384
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
385
+
386
+ return test_set
387
+
388
+
389
+ # get librispeech test-clean cross sentence test
390
+
391
+ def get_librispeech_test(metalst, gen_wav_dir, gpus, librispeech_test_clean_path, eval_ground_truth = False):
392
+ f = open(metalst)
393
+ lines = f.readlines()
394
+ f.close()
395
+
396
+ test_set_ = []
397
+ for line in tqdm(lines):
398
+ ref_utt, ref_dur, ref_txt, gen_utt, gen_dur, gen_txt = line.strip().split('\t')
399
+
400
+ if eval_ground_truth:
401
+ gen_spk_id, gen_chaptr_id, _ = gen_utt.split('-')
402
+ gen_wav = os.path.join(librispeech_test_clean_path, gen_spk_id, gen_chaptr_id, gen_utt + '.flac')
403
+ else:
404
+ if not os.path.exists(os.path.join(gen_wav_dir, gen_utt + '.wav')):
405
+ raise FileNotFoundError(f"Generated wav not found: {gen_utt}")
406
+ gen_wav = os.path.join(gen_wav_dir, gen_utt + '.wav')
407
+
408
+ ref_spk_id, ref_chaptr_id, _ = ref_utt.split('-')
409
+ ref_wav = os.path.join(librispeech_test_clean_path, ref_spk_id, ref_chaptr_id, ref_utt + '.flac')
410
+
411
+ test_set_.append((gen_wav, ref_wav, gen_txt))
412
+
413
+ num_jobs = len(gpus)
414
+ if num_jobs == 1:
415
+ return [(gpus[0], test_set_)]
416
+
417
+ wav_per_job = len(test_set_) // num_jobs + 1
418
+ test_set = []
419
+ for i in range(num_jobs):
420
+ test_set.append((gpus[i], test_set_[i*wav_per_job:(i+1)*wav_per_job]))
421
+
422
+ return test_set
423
+
424
+
425
+ # load asr model
426
+
427
+ def load_asr_model(lang, ckpt_dir = ""):
428
+ if lang == "zh":
429
+ from funasr import AutoModel
430
+ model = AutoModel(
431
+ model = os.path.join(ckpt_dir, "paraformer-zh"),
432
+ # vad_model = os.path.join(ckpt_dir, "fsmn-vad"),
433
+ # punc_model = os.path.join(ckpt_dir, "ct-punc"),
434
+ # spk_model = os.path.join(ckpt_dir, "cam++"),
435
+ disable_update=True,
436
+ ) # following seed-tts setting
437
+ elif lang == "en":
438
+ from faster_whisper import WhisperModel
439
+ model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
440
+ model = WhisperModel(model_size, device="cuda", compute_type="float16")
441
+ return model
442
+
443
+
444
+ # WER Evaluation, the way Seed-TTS does
445
+
446
+ def run_asr_wer(args):
447
+ rank, lang, test_set, ckpt_dir = args
448
+
449
+ if lang == "zh":
450
+ import zhconv
451
+ torch.cuda.set_device(rank)
452
+ elif lang == "en":
453
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
454
+ else:
455
+ raise NotImplementedError("lang support only 'zh' (funasr paraformer-zh), 'en' (faster-whisper-large-v3), for now.")
456
+
457
+ asr_model = load_asr_model(lang, ckpt_dir = ckpt_dir)
458
+
459
+ from zhon.hanzi import punctuation
460
+ punctuation_all = punctuation + string.punctuation
461
+ wers = []
462
+
463
+ from jiwer import compute_measures
464
+ for gen_wav, prompt_wav, truth in tqdm(test_set):
465
+ if lang == "zh":
466
+ res = asr_model.generate(input=gen_wav, batch_size_s=300, disable_pbar=True)
467
+ hypo = res[0]["text"]
468
+ hypo = zhconv.convert(hypo, 'zh-cn')
469
+ elif lang == "en":
470
+ segments, _ = asr_model.transcribe(gen_wav, beam_size=5, language="en")
471
+ hypo = ''
472
+ for segment in segments:
473
+ hypo = hypo + ' ' + segment.text
474
+
475
+ # raw_truth = truth
476
+ # raw_hypo = hypo
477
+
478
+ for x in punctuation_all:
479
+ truth = truth.replace(x, '')
480
+ hypo = hypo.replace(x, '')
481
+
482
+ truth = truth.replace(' ', ' ')
483
+ hypo = hypo.replace(' ', ' ')
484
+
485
+ if lang == "zh":
486
+ truth = " ".join([x for x in truth])
487
+ hypo = " ".join([x for x in hypo])
488
+ elif lang == "en":
489
+ truth = truth.lower()
490
+ hypo = hypo.lower()
491
+
492
+ measures = compute_measures(truth, hypo)
493
+ wer = measures["wer"]
494
+
495
+ # ref_list = truth.split(" ")
496
+ # subs = measures["substitutions"] / len(ref_list)
497
+ # dele = measures["deletions"] / len(ref_list)
498
+ # inse = measures["insertions"] / len(ref_list)
499
+
500
+ wers.append(wer)
501
+
502
+ return wers
503
+
504
+
505
+ # SIM Evaluation
506
+
507
+ def run_sim(args):
508
+ rank, test_set, ckpt_dir = args
509
+ device = f"cuda:{rank}"
510
+
511
+ model = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type='wavlm_large', config_path=None)
512
+ state_dict = torch.load(ckpt_dir, weights_only=True, map_location=lambda storage, loc: storage)
513
+ model.load_state_dict(state_dict['model'], strict=False)
514
+
515
+ use_gpu=True if torch.cuda.is_available() else False
516
+ if use_gpu:
517
+ model = model.cuda(device)
518
+ model.eval()
519
+
520
+ sim_list = []
521
+ for wav1, wav2, truth in tqdm(test_set):
522
+
523
+ wav1, sr1 = torchaudio.load(wav1)
524
+ wav2, sr2 = torchaudio.load(wav2)
525
+
526
+ resample1 = torchaudio.transforms.Resample(orig_freq=sr1, new_freq=16000)
527
+ resample2 = torchaudio.transforms.Resample(orig_freq=sr2, new_freq=16000)
528
+ wav1 = resample1(wav1)
529
+ wav2 = resample2(wav2)
530
+
531
+ if use_gpu:
532
+ wav1 = wav1.cuda(device)
533
+ wav2 = wav2.cuda(device)
534
+ with torch.no_grad():
535
+ emb1 = model(wav1)
536
+ emb2 = model(wav2)
537
+
538
+ sim = F.cosine_similarity(emb1, emb2)[0].item()
539
+ # print(f"VSim score between two audios: {sim:.4f} (-1.0, 1.0).")
540
+ sim_list.append(sim)
541
+
542
+ return sim_list
543
+
544
+
545
+ # filter func for dirty data with many repetitions
546
+
547
+ def repetition_found(text, length = 2, tolerance = 10):
548
+ pattern_count = defaultdict(int)
549
+ for i in range(len(text) - length + 1):
550
+ pattern = text[i:i + length]
551
+ pattern_count[pattern] += 1
552
+ for pattern, count in pattern_count.items():
553
+ if count > tolerance:
554
+ return True
555
+ return False
556
+
557
+
558
+ # load model checkpoint for inference
559
+
560
+ def load_checkpoint(model, ckpt_path, device, use_ema = True):
561
+ from ema_pytorch import EMA
562
+
563
+ ckpt_type = ckpt_path.split(".")[-1]
564
+ if ckpt_type == "safetensors":
565
+ from safetensors.torch import load_file
566
+ checkpoint = load_file(ckpt_path, device=device)
567
+ else:
568
+ checkpoint = torch.load(ckpt_path, weights_only=True, map_location=device)
569
+
570
+ if use_ema == True:
571
+ ema_model = EMA(model, include_online_model = False).to(device)
572
+ if ckpt_type == "safetensors":
573
+ ema_model.load_state_dict(checkpoint)
574
+ else:
575
+ ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
576
+ ema_model.copy_params_from_ema_to_model()
577
+ else:
578
+ model.load_state_dict(checkpoint['model_state_dict'])
579
+
580
+ return model