Amamrnaf committed on
Commit 1a73edf · 1 Parent(s): a62a5cc
Files changed (1)
  1. metaVoice.py +785 -9
metaVoice.py CHANGED
@@ -1,6 +1,43 @@
1
  from fam.llm.fast_inference import TTS
2
  import string
3
  import soundfile as sf
4
 
5
  def remove_punctuation(sentence):
6
  translator = str.maketrans('', '', string.punctuation)
@@ -11,8 +48,26 @@ def remove_punctuation(sentence):
11
 
12
  return sentence
13
 
14
- def run_audio_generation_v2(new_text,accent='None'):
15
- tts = TTS()
16
  new_text = new_text.replace('\n', ' ').replace('\r', '')
17
  new_text_mod = remove_punctuation(new_text)
18
 
@@ -20,11 +75,732 @@ def run_audio_generation_v2(new_text,accent='None'):
20
  for word in new_text_split:
21
  if len(word)>=2 and word.isupper():
22
  new_text = new_text.replace(word, " ".join([*word]))
23
 
24
- wav_file = tts.synthesise(
25
- text=new_text,
26
- spk_ref_path="./tmp/audio/speaker_wav.wav" # you can use any speaker reference file (WAV, OGG, MP3, FLAC, etc.)
27
- )
28
- sf.write('audio/output.wav', wav_file, samplerate=22050)
29
-
30
-
 
1
  from fam.llm.fast_inference import TTS
2
  import string
3
+ import json
4
+ from glob import glob
5
+ import torch
6
+ import os
7
+ import torchaudio
8
+ import subprocess
9
+ import shutil
10
  import soundfile as sf
11
+ import pyloudnorm as pyln
12
+ import noisereduce as nr
13
+ from moviepy.editor import *
14
+ from pydub import AudioSegment
15
+ from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook, TiltedEncodec
16
+ from fam.llm.decoders import Decoder, EncodecDecoder
17
+ from fam.llm.enhancers import BaseEnhancer, get_enhancer
18
+ from fam.llm.model import GPT, GPTConfig
19
+ from fam.llm.utils import (
20
+ check_audio_file,
21
+ get_default_dtype,
22
+ get_default_use_kv_cache,
23
+ normalize_text,
24
+ )
25
+ from fam.quantiser.audio.speaker_encoder.model import SpeakerEncoder
26
+ from fam.quantiser.text.tokenise import TrainedBPETokeniser
27
+ import tyro
28
+ from huggingface_hub import snapshot_download
29
+ from typing import List, Literal, Optional, Tuple, Type, Union
30
+ import dataclasses
31
+ import hashlib
32
+ import json
33
+ import os
34
+ import pathlib
35
+ from contextlib import nullcontext
36
+ from dataclasses import dataclass
37
+ import tqdm
38
+ import tqdm.contrib.concurrent
39
+ import tempfile
40
+ import textwrap
41
 
42
  def remove_punctuation(sentence):
43
  translator = str.maketrans('', '', string.punctuation)
 
48
 
49
  return sentence
50
 
51
+ # def run_audio_generation_v2(new_text,accent='None'):
52
+ # tts = TTS()
53
+ # new_text = new_text.replace('\n', ' ').replace('\r', '')
54
+ # new_text_mod = remove_punctuation(new_text)
55
+
56
+ # new_text_split = new_text_mod.split()
57
+ # for word in new_text_split:
58
+ # if len(word)>=2 and word.isupper():
59
+ # new_text = new_text.replace(word, " ".join([*word]))
60
+
61
+ # wav_file = tts.synthesise(
62
+ # text=new_text,
63
+ # spk_ref_path="./tmp/audio/speaker_wav.wav" # you can use any speaker reference file (WAV, OGG, MP3, FLAC, etc.)
64
+ # )
65
+ # sf.write('audio/output.wav', wav_file, samplerate=22050)
66
+
67
+
68
+
69
+ def run_audio_generation_v2(new_text, accent=None):
70
+ # Check for abbreviations in the text: all-caps words are spaced out letter by letter so the audio comes out correctly.
71
  new_text = new_text.replace('\n', ' ').replace('\r', '')
72
  new_text_mod = remove_punctuation(new_text)
73
 
 
75
  for word in new_text_split:
76
  if len(word)>=2 and word.isupper():
77
  new_text = new_text.replace(word, " ".join([*word]))
78
+ print(new_text)
79
+
80
+
81
+ if len(new_text)<=220:
82
+ sampling_config = SamplingControllerConfig(spk_cond_path="./tmp/audio/input_src/0.wav", text=new_text, output_dir='./tmp/audio/', output_name='generated-custom.wav')
83
+ metavoice_gen(sampling_config)
84
+ else:
85
+ new_texts = new_text.split('. ') #textwrap.wrap(new_text, 220)
86
+ new_texts = [txt if txt.endswith(".") else txt + "." for txt in new_texts]
87
+ output_names = []
88
+ for idx, new_text in enumerate(new_texts):
89
+ output_name = f"generated-custom-{idx}.wav"
90
+ output_names.append(output_name)
91
+ sampling_config = SamplingControllerConfig(spk_cond_path="./tmp/audio/input_src/0.wav", text=new_text, output_dir='./tmp/audio/multiple/', output_name=output_name)
92
+ metavoice_gen(sampling_config)
93
+
94
+ #audio_files = ['./tmp/audio/multiple/'+'/'+ aud for aud in os.listdir('./tmp/audio/multiple/') if aud.endswith(".wav")]
95
+ audio_files = ['./tmp/audio/multiple/' + aud for aud in output_names]
96
+ print(audio_files)
97
+ clips = [(AudioFileClip(clip)) for clip in audio_files]
98
+ final_clip = concatenate_audioclips(clips)
99
+ final_clip.write_audiofile('./tmp/audio/generated-custom.wav')
100
+
101
+ # adjust loudness
102
+ data, rate = sf.read("./tmp/audio/input_audio.wav") # load audio (with shape (samples, channels))
103
+ meter = pyln.Meter(rate) # create BS.1770 meter
104
+ loudness_target = meter.integrated_loudness(data) # measure loudness
105
+
106
+ mod_data, mod_rate = sf.read("./tmp/audio/generated-custom.wav") # load audio (with shape (samples, channels))
107
+ mod_meter = pyln.Meter(mod_rate) # create BS.1770 meter
108
+ loudness_gen = mod_meter.integrated_loudness(mod_data) # measure loudness
109
+
110
+ loudness_normalized_gen = pyln.normalize.loudness(mod_data, loudness_gen, loudness_target)
111
+ sf.write('./tmp/audio/generated-custom.wav', loudness_normalized_gen, mod_rate)
112
+
113
+ @dataclass
114
+ class InferenceConfig:
115
+ ckpt_path: str # path to checkpoint
116
+ output_dir: str
117
+ num_samples: int = 10 # number of samples to draw
118
+ seed: int = 1337 # random seed
119
+ device: str = "cuda"
120
+ dtype: str = "bfloat16"
121
+ compile: bool = False
122
+ init_from: str = "resume" # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')
123
+
124
+ def __str__(self):
125
+ field_strs = []
126
+ for field in dataclasses.fields(self):
127
+ value = getattr(self, field.name)
128
+ field_strs.append(f" {field.name}: {value}")
129
+
130
+ return "InferenceConfig:\n" + "\n".join(field_strs)
131
+
132
+
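+ # Wraps one GPT stage of the MetaVoice pipeline: loads a checkpoint, builds the
+ # tokenizer and decoder, and exposes causal / non-causal sampling.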
133
+ class Model:
134
+ def __init__(
135
+ self,
136
+ config: InferenceConfig,
137
+ tokenizer_cls: Type[TrainedBPETokeniser],
138
+ decoder_cls: Type[Decoder],
139
+ data_adapter_fn,
140
+ use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = None,
141
+ ):
142
+ # TODO: disentangle the encodec stuff and numbers etc with rest of this code (esp at encoder-only / second stage model inference)
143
+ # TODO: remove magic number
144
+ self._encodec_codes_pad_token = 1024
145
+ self._num_encodec_codebooks = 8
146
+ self.config = config
147
+ self.use_kv_cache = use_kv_cache
148
+
149
+ torch.manual_seed(config.seed)
150
+ torch.cuda.manual_seed(config.seed)
151
+ torch.backends.cuda.matmul.allow_tf32 = True if config.dtype != "float32" else False # allow tf32 on matmul
152
+ torch.backends.cudnn.allow_tf32 = True if config.dtype != "float32" else False # allow tf32 on cudnn
153
+ device_type = "cuda" if "cuda" in config.device else "cpu" # for later use in torch.autocast
154
+ self.ptdtype = {
155
+ "float32": torch.float32,
156
+ "tfloat32": torch.float32,
157
+ "bfloat16": torch.bfloat16,
158
+ "float16": torch.float16,
159
+ }[config.dtype]
160
+ self._ctx = (
161
+ nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=self.ptdtype)
162
+ )
163
+
164
+ self.use_bpe_tokenizer = False
165
+ self.load_meta = None
166
+ self.speaker_cond = None
167
+ self.meta = None
168
+ self.model = None
169
+ self.checkpoint_config = None
170
+ self.vocab_sizes = None
171
+ self.smodel = None
172
+
173
+ self._init_model()
174
+
175
+ self.tokenizer = tokenizer_cls(**self.meta["tokenizer"])
176
+ self.decoder = decoder_cls(
177
+ tokeniser_decode_fn=self.tokenizer.decode,
178
+ output_dir=self.config.output_dir,
179
+ data_adapter_fn=data_adapter_fn,
180
+ )
181
+
182
+ def _init_model(self):
183
+ if self.config.init_from == "resume":
184
+ # init from a model saved in a specific directory
185
+ checkpoint = torch.load(self.config.ckpt_path, map_location=self.config.device)
186
+ self.vocab_sizes = checkpoint["model_args"]["vocab_sizes"]
187
+
188
+ self.load_meta = False
189
+ self.speaker_cond = False
190
+
191
+ if "config" in checkpoint:
192
+ self.checkpoint_config = checkpoint["config"]
193
+
194
+ self.meta = checkpoint["meta"]
195
+ load_meta = True
196
+
197
+ if load_meta:
198
+ self.use_bpe_tokenizer = "stoi" not in self.meta or "itos" not in self.meta
199
+ self.speaker_cond = self.meta.get("speaker_cond")
200
+
201
+ if self.speaker_cond:
202
+ speaker_emb_size = self.meta["speaker_emb_size"]
203
+
204
+ model_args = checkpoint["model_args"]
205
+ if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
206
+ self._encodec_ctx_window = model_args["block_size"]
207
+
208
+ gptconf = GPTConfig(**model_args)
209
+
210
+ # TODO: rename `speaker_emb_dim` to `speaker_emb_size`.
211
+ self.model = GPT(gptconf, speaker_emb_dim=speaker_emb_size if self.speaker_cond else None)
212
+ state_dict = checkpoint["model"]
213
+ unwanted_prefix = "_orig_mod."
214
+ for k, v in list(state_dict.items()):
215
+ if k.startswith(unwanted_prefix):
216
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
217
+ self.model.load_state_dict(state_dict)
218
+
219
+ # model
220
+ self.model.eval()
221
+ self.model.to(self.config.device)
222
+
223
+ if self.config.compile:
224
+ from einops._torch_specific import allow_ops_in_compiled_graph
225
+
226
+ allow_ops_in_compiled_graph()
227
+ self.model = torch.compile(self.model) # type: ignore
228
+
229
+ if self.use_kv_cache is not None:
230
+ if "causal" in self.checkpoint_config and self.checkpoint_config["causal"] is False:
231
+ raise Exception("kv_cache not supported for non-causal models!")
232
+
233
+ if self.use_kv_cache == "flash_decoding":
234
+ self.model.enable_kv_cache()
235
+ for block in self.model.transformer.h:
236
+ block.attn.attn_kernel_type = "fd"
237
+ elif self.use_kv_cache == "vanilla":
238
+ self.model.enable_kv_cache()
239
+ else:
240
+ raise NotImplementedError(f"kv_cache type {self.use_kv_cache} not implemented!")
241
+
242
+ def causal_sample(
243
+ self,
244
+ *,
245
+ texts: list[str],
246
+ batch_size: int,
247
+ max_new_tokens: int,
248
+ temperature: Optional[float],
249
+ top_k: Optional[int],
250
+ top_p: Optional[float],
251
+ speaker_embs: Optional[torch.Tensor] = None,
252
+ guidance_scale: Optional[float] = None,
253
+ ) -> list[torch.Tensor]:
254
+ """
255
+ Returns list of torch.Tensors of tokens. Each tensor is of shape (1, c, t) where c is the number of codebooks.
256
+ Any flattening / interleaving / tilting gets reversed before the output is returned.
257
+ """
258
+ if speaker_embs is not None:
259
+ assert len(texts) == len(speaker_embs)
260
+
261
+ encoded_texts = [self.tokenizer.encode(text) for text in texts]
262
+
263
+ ## create multiple hierarchies and get seq_lens
264
+ seq_lens = []
265
+ xs = []
266
+ for i, encoded_text in enumerate(encoded_texts):
267
+ encoded_text = torch.tensor([encoded_text], dtype=torch.long, device=self.config.device)
268
+ # TODO: remove magic number
269
+ xs.append(
270
+ torch.cat(
271
+ # [1st hierarchy of text, *remaining hierarchies of padded tokens]
272
+ # TODO: self.vocab_sizes should be from the model config?
273
+ [encoded_text, *[torch.ones_like(encoded_text) * 1024] * (len(self.vocab_sizes) - 1)],
274
+ dim=0,
275
+ ).unsqueeze(0)
276
+ ) # b x [(b=1, c, t)]
277
+ seq_lens.append(xs[-1].shape[-1])
278
+ max_len = max(seq_lens)
279
+ assert len(xs) == len(seq_lens)
280
+
281
+ ## equalise the shapes in the batch. we can use torch.zeros as tokens > seq_lens will be masked out.
282
+ x = torch.zeros((len(encoded_texts), xs[0].shape[1], max_len), dtype=torch.long, device=self.config.device)
283
+ for i, _xs in enumerate(xs):
284
+ assert _xs.shape[-1] == seq_lens[i]
285
+ x[i, :, : seq_lens[i]] = _xs
286
+
287
+ ## check that the input is correct
288
+ for i in range(x.shape[0]):
289
+ assert x[i, 0, : seq_lens[i]].tolist() == encoded_texts[i]
290
+
291
+ # TODO: remove magic number
292
+ if x.shape[1] > 1:
293
+ assert set(x[i, 1, : seq_lens[i]].tolist()) == set([1024])
294
+
295
+ assert x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True
296
+
297
+ if self.speaker_cond is False:
298
+ speaker_embs = None
299
+
300
+ # run sampling loop
301
+ with torch.no_grad():
302
+ with self._ctx: # type: ignore
303
+ to_return = []
304
+ for k in range(self.config.num_samples):
305
+ assert seq_lens is not None
306
+ assert batch_size is not None
307
+
308
+ if max(seq_lens) + max_new_tokens >= self.model.config.block_size:
309
+ raise Exception(
310
+ f"max_new_tokens {max_new_tokens} too large! Choose {self.model.config.block_size - max(seq_lens) - 1} instead."
311
+ )
312
+
313
+ y = self.model.generate(
314
+ x,
315
+ max_new_tokens,
316
+ seq_lens=seq_lens,
317
+ temperature=temperature,
318
+ top_k=top_k,
319
+ top_p=top_p,
320
+ speaker_embs=speaker_embs,
321
+ batch_size=batch_size,
322
+ guidance_scale=guidance_scale,
323
+ dtype=self.ptdtype,
324
+ end_of_audio_token=self.tokenizer.offset - 1,
325
+ end_of_text_token=self.tokenizer.eot_token,
326
+ )
327
+ for i in range(len(y)):
328
+ to_return.append(self.decoder.decode(tokens=y[i].tolist(), causal=True))
329
+
330
+ return to_return
331
+
332
+ def non_causal_sample(
333
+ self,
334
+ *,
335
+ texts: list[str],
336
+ encodec_tokens: list[torch.Tensor],
337
+ batch_size: int,
338
+ top_k: Optional[int],
339
+ temperature: Optional[float],
340
+ speaker_embs: Optional[torch.Tensor] = None,
341
+ ) -> list[str]:
342
+ """
343
+ Returns paths to saved audio files.
344
+ """
345
+ if speaker_embs is not None:
346
+ assert len(texts) == len(speaker_embs)
347
+
348
+ encoded_texts = [self.tokenizer.encode(text) for text in texts]
349
+
350
+ # setup input
351
+ # TODO: same code is used during data prep. refactor
352
+ padded_hierarchies_inputs = []
353
+ for encoded_text, encodec_token in zip(encoded_texts, encodec_tokens):
354
+ x = torch.tensor(encoded_text, dtype=torch.long, device=self.config.device)[
355
+ None, None, ...
356
+ ] # (b=1, c=1, t)
357
+
358
+ # TODO: should this only happen if the decoder is an EncodecDecoder?
359
+ assert encodec_token.shape[0] == 1
360
+ encodec_token = encodec_token[0].tolist() # (b=1, c, t) -> (c, t)
361
+ assert len(encodec_token) >= 1 and len(encodec_token) <= self._num_encodec_codebooks
362
+
363
+ ## setup hierarchies of tokens
364
+ # TODO: refactor and merge with code in processing.py
365
+ text_tokens = encoded_text # (t,)
366
+
367
+ hierarchies_in = []
368
+ hierarchies_in.append(text_tokens + encodec_token[0] + [self._encodec_codes_pad_token])
369
+ hierarchies_in.append(
370
+ [self._encodec_codes_pad_token] * len(text_tokens) + encodec_token[1] + [self._encodec_codes_pad_token]
371
+ )
372
+
373
+ ## adding padding / cutting to the right size as needed
374
+ # TODO: refactor and merge with code in processing.py
375
+ padded_hierarchies_input = []
376
+ for _, t_hierarchy in enumerate(hierarchies_in):
377
+ assert len(t_hierarchy) == len(hierarchies_in[0])
378
+ if len(t_hierarchy) < self._encodec_ctx_window:
379
+ padded_hierarchies_input.append(
380
+ t_hierarchy + [self._encodec_codes_pad_token] * (self._encodec_ctx_window - len(t_hierarchy))
381
+ )
382
+ elif len(t_hierarchy) > self._encodec_ctx_window:
383
+ padded_hierarchies_input.append(t_hierarchy[: self._encodec_ctx_window])
384
+ else:
385
+ padded_hierarchies_input.append(t_hierarchy)
386
+
387
+ padded_hierarchies_inputs.append(padded_hierarchies_input)
388
+
389
+ ## check that the input is correct
390
+ in_x = torch.tensor(padded_hierarchies_inputs, dtype=torch.long, device=self.config.device)
391
+ assert in_x.shape[0] == speaker_embs.shape[0] if speaker_embs is not None else True
392
+
393
+ if self.speaker_cond is False:
394
+ speaker_embs = None
395
+
396
+ # run sampling loop
397
+ with torch.no_grad():
398
+ with self._ctx: # type: ignore
399
+ to_return = []
400
+ for k in range(self.config.num_samples):
401
+ y = self.model.generate(
402
+ in_x,
403
+ None,
404
+ temperature=temperature,
405
+ top_k=top_k,
406
+ # TODO: handle separate top_p for this model explicitly
407
+ top_p=None,
408
+ speaker_embs=speaker_embs,
409
+ batch_size=batch_size,
410
+ guidance_scale=None,
411
+ )
412
+
413
+ b_tokens = torch.cat([in_x, y], dim=1)
414
+ for tokens in b_tokens:
415
+ try:
416
+ to_return.append(self.decoder.decode(tokens=tokens.tolist(), causal=False))
417
+ except Exception as e:
418
+ print("failed to run MBD.")
419
+ print(f"reason: {str(e)}")
420
+ to_return.append(None)
421
+
422
+ return to_return
423
+
424
+ def __call__(
425
+ self,
426
+ *,
427
+ texts: list[str],
428
+ batch_size: int,
429
+ max_new_tokens: Optional[int],
430
+ top_k: Optional[int],
431
+ top_p: Optional[float],
432
+ temperature: Optional[float],
433
+ encodec_tokens: Optional[list[torch.Tensor]] = None,
434
+ speaker_embs: Optional[torch.Tensor] = None,
435
+ guidance_scale: Optional[float] = None,
436
+ ):
437
+ if self.checkpoint_config.get("causal", True):
438
+ return self.causal_sample(
439
+ texts=texts,
440
+ batch_size=batch_size,
441
+ speaker_embs=speaker_embs,
442
+ guidance_scale=guidance_scale,
443
+ max_new_tokens=max_new_tokens,
444
+ top_k=top_k,
445
+ top_p=top_p,
446
+ temperature=temperature,
447
+ )
448
+ else:
449
+ assert encodec_tokens is not None
450
+ assert guidance_scale is None
451
+ assert max_new_tokens is None
452
+ assert top_p is None
453
+
454
+ return self.non_causal_sample(
455
+ texts=texts,
456
+ encodec_tokens=encodec_tokens,
457
+ batch_size=batch_size,
458
+ speaker_embs=speaker_embs,
459
+ top_k=top_k,
460
+ temperature=temperature,
461
+ )
462
+
463
+
464
+ def save_result_metadata(wav_path, ref_path, text, first_stage_ckpt_path, second_stage_ckpt_path):
465
+ if first_stage_ckpt_path is None or second_stage_ckpt_path is None:
466
+ return
467
+ json.dump(
468
+ {
469
+ "speaker": ref_path,
470
+ "text": text,
471
+ },
472
+ pathlib.Path(str(wav_path) + ".json").open("w"),
473
+ )
474
+
475
+
476
+ def get_cached_file(file_or_uri: str):
477
+ """
478
+ If it's an s3 file, download it to a local temporary file and return that path.
479
+ Otherwise return the path as is.
480
+ """
481
+ is_uri = file_or_uri.startswith("http")
482
+
483
+ cache_path = None
484
+ if is_uri:
485
+ ext = pathlib.Path(file_or_uri).suffix
486
+ # hash the file path to get the cache name
487
+ _cache_name = "audio_" + hashlib.md5(file_or_uri.encode("utf-8")).hexdigest() + ext
488
+
489
+ os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
490
+ cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
491
+
492
+ if not os.path.exists(cache_path):
493
+ command = f"curl -o {cache_path} {file_or_uri}"
494
+ subprocess.run(command, shell=True, check=True)
495
+ else:
496
+ if os.path.exists(file_or_uri):
497
+ cache_path = file_or_uri
498
+ else:
499
+ raise FileNotFoundError(f"File {file_or_uri} not found!")
500
+ return cache_path
501
+
502
+
503
+ def get_cached_embedding(local_file_path: str, spkemb_model):
504
+ if not os.path.exists(local_file_path):
505
+ raise FileNotFoundError(f"File {local_file_path} not found!")
506
+
507
+ # hash the file path to get the cache name
508
+ _cache_name = "embedding_" + hashlib.md5(local_file_path.encode("utf-8")).hexdigest() + ".pt"
509
+
510
+ os.makedirs(os.path.expanduser("~/.cache/fam/"), exist_ok=True)
511
+ cache_path = os.path.expanduser(f"~/.cache/fam/{_cache_name}")
512
+
513
+ if not os.path.exists(cache_path):
514
+ spk_emb = spkemb_model.embed_utterance_from_file(local_file_path, numpy=False).unsqueeze(0) # (b=1, c)
515
+ torch.save(spk_emb, cache_path)
516
+ else:
517
+ spk_emb = torch.load(cache_path)
518
+
519
+ return spk_emb
520
+
521
+
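+ # Batched end-to-end synthesis: caches speaker reference files, computes speaker
+ # embeddings, runs the first-stage and second-stage models, optionally enhances the
+ # output, and copies each result to output_dir/output_name alongside a metadata JSON.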
522
+ def _sample_utterance_batch(
523
+ texts: list[str],
524
+ spk_cond_paths: list[Optional[str]],
525
+ spkemb_model,
526
+ first_stage_model,
527
+ second_stage_model,
528
+ enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
529
+ first_stage_ckpt_path: str,
530
+ second_stage_ckpt_path: str,
531
+ guidance_scale: Optional[Tuple[float, float]],
532
+ max_new_tokens: int,
533
+ top_k: Optional[int],
534
+ top_p: Optional[float],
535
+ temperature: Optional[float],
536
+ output_name: str,
537
+ output_dir: str,
538
+ batch_size: int = 128,
539
+ ) -> List[str]:
540
+
541
+ speaker_embs = []
542
+ refs = spk_cond_paths.copy()
543
+
544
+ # multithreaded loop to cache all the files
545
+ spk_cond_paths = tqdm.contrib.concurrent.thread_map(
546
+ get_cached_file, spk_cond_paths, desc="getting cached speaker ref files"
547
+ )
548
+
549
+ for i, (text, spk_cond_path) in tqdm.tqdm(
550
+ enumerate(zip(texts, spk_cond_paths)), total=len(texts), desc="calculating speaker embeddings"
551
+ ):
552
+ texts[i] = normalize_text(text)
553
+ speaker_embs.append(get_cached_embedding(spk_cond_path, spkemb_model) if spk_cond_path else None)
554
+
555
+ b_speaker_embs = torch.cat(speaker_embs, dim=0)
556
+ b_tokens = first_stage_model(
557
+ texts=texts,
558
+ speaker_embs=b_speaker_embs,
559
+ batch_size=batch_size,
560
+ guidance_scale=guidance_scale,
561
+ top_p=top_p,
562
+ top_k=top_k,
563
+ temperature=temperature,
564
+ max_new_tokens=max_new_tokens,
565
+ )
566
+
567
+ # TODO: set batch size for second stage model!
568
+ wav_files = second_stage_model(
569
+ texts=texts,
570
+ encodec_tokens=b_tokens,
571
+ speaker_embs=b_speaker_embs,
572
+ batch_size=batch_size,
573
+ guidance_scale=None,
574
+ top_p=None,
575
+ top_k=top_k,
576
+ temperature=temperature,
577
+ max_new_tokens=None,
578
+ )
579
+
580
+ for text, tokens, speaker_embs, ref_name, wav_file in zip(texts, b_tokens, b_speaker_embs, refs, wav_files):
581
+ if wav_file is None:
582
+ continue
583
+
584
+ with tempfile.NamedTemporaryFile(suffix=".wav") as enhanced_tmp:
585
+ if enhancer is not None:
586
+ enhancer = get_enhancer(enhancer) if isinstance(enhancer, str) else enhancer
587
+ enhancer(str(wav_file) + ".wav", enhanced_tmp.name)
588
+ # copy enhanced_tmp.name back to wav_file
589
+ print(f"copying enhanced file from {enhanced_tmp.name} to {str(wav_file) + '.wav'}.")
590
+ shutil.copy2(enhanced_tmp.name, str(wav_file) + ".wav")
591
+ shutil.copy2(str(wav_file) + ".wav", os.path.join(output_dir, output_name))
592
+
593
+ save_result_metadata(
594
+ wav_file,
595
+ ref_name,
596
+ text,
597
+ first_stage_ckpt_path,
598
+ second_stage_ckpt_path,
599
+ )
600
+ return [str(w) + ".wav" if not str(w).endswith(".wav") else str(w) for w in wav_files]
601
+
602
+
603
+ def sample_utterance(
604
+ text: str,
605
+ spk_cond_path: Optional[str],
606
+ spkemb_model,
607
+ first_stage_model,
608
+ second_stage_model,
609
+ enhancer: Optional[Union[Literal["df"], BaseEnhancer]],
610
+ first_stage_ckpt_path: str,
611
+ second_stage_ckpt_path: str,
612
+ guidance_scale: Optional[Tuple[float, float]],
613
+ max_new_tokens: int,
614
+ top_k: Optional[int],
615
+ top_p: Optional[float],
616
+ temperature: Optional[float],
617
+ output_name: str,
618
+ output_dir: str,
619
+ ) -> str:
620
+ # NOTE: supports max. 220 characters atm.
621
+ # Long form synthesis coming soon...
622
+ MAX_CHARS = 220
623
+ if len(text) > MAX_CHARS:
624
+ print(
625
+ f"\n***WARNING: Max {MAX_CHARS} characters supported. Provided: {len(text)}. Truncating and generating speech...Can lead to unpredictable speech at the end.***"
626
+ )
627
+
628
+ return _sample_utterance_batch(
629
+ texts=[text],
630
+ spk_cond_paths=[spk_cond_path],
631
+ spkemb_model=spkemb_model,
632
+ first_stage_model=first_stage_model,
633
+ second_stage_model=second_stage_model,
634
+ enhancer=enhancer,
635
+ first_stage_ckpt_path=first_stage_ckpt_path,
636
+ second_stage_ckpt_path=second_stage_ckpt_path,
637
+ batch_size=1,
638
+ guidance_scale=guidance_scale,
639
+ max_new_tokens=max_new_tokens,
640
+ top_k=top_k,
641
+ top_p=top_p,
642
+ temperature=temperature,
643
+ output_name = output_name,
644
+ output_dir = output_dir
645
+ )[0]
646
+
647
+
648
+ def build_models(config_first_stage, config_second_stage, model_dir, device, use_kv_cache):
649
+ smodel = SpeakerEncoder(
650
+ weights_fpath=os.path.join(model_dir, "speaker_encoder.pt"), device=device, eval=True, verbose=False
651
+ )
652
+ data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
653
+ llm_first_stage = Model(
654
+ config_first_stage,
655
+ TrainedBPETokeniser,
656
+ EncodecDecoder,
657
+ data_adapter_fn=data_adapter.decode,
658
+ use_kv_cache=use_kv_cache,
659
+ )
660
+ data_adapter_second_stage = TiltedEncodec(end_of_audio_token=1024)
661
+ llm_second_stage = Model(
662
+ config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
663
+ )
664
+ return smodel, llm_first_stage, llm_second_stage
665
+
666
+
667
+ def get_first_stage_path(model_dir: str):
668
+ """Absolute path to checkpoint for the first stage model."""
669
+ return os.path.join(os.path.expanduser(model_dir), "first_stage.pt")
670
+
671
+
672
+ def get_second_stage_path(model_dir: str):
673
+ """Absolute path to checkpoint for the second stage model."""
674
+ return os.path.join(os.path.expanduser(model_dir), "second_stage.pt")
675
+
676
+
677
+ @dataclass
678
+ class SamplingControllerConfig:
679
+ """
680
+ Sample from a trained model.
681
+ """
682
+
683
+ spk_cond_path: str
684
+ """Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3"""
685
+
686
+ huggingface_repo_id: str = "metavoiceio/metavoice-1B-v0.1"
687
+ """Hugging Face Hub repo id to download the model checkpoints from."""
688
+
689
+ text: str = (
690
+ "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model by MetaVoice."
691
+ )
692
+ """Text to synthesise."""
693
+
694
+ num_samples: int = 1
695
+ """Number of samples to generate from each model."""
696
+
697
+ max_new_tokens: int = 864 * 2
698
+ """Maximum number of new tokens to generate from the first stage model."""
699
+
700
+ temperature: float = 1.0
701
+ """Temperature for sampling applied to both models."""
702
+
703
+ top_k: Optional[int] = 200
704
+ """Top k for sampling applied to both models."""
705
+
706
+ top_p: Optional[float] = None
707
+ """Top p for sampling applied to first-stage model."""
708
+
709
+ seed: int = 1337
710
+ """Random seed for sampling."""
711
+
712
+ device: Literal["cuda", "cpu"] = "cuda"
713
+ """Device to use for sampling."""
714
+
715
+ dtype: Literal["bfloat16", "float16", "float32", "tfloat32"] = get_default_dtype()
716
+ """Data type to use for sampling."""
717
+
718
+ compile: bool = False
719
+ """Whether to compile the model using PyTorch 2.0."""
720
+
721
+ enhancer: Optional[Literal["df"]] = "df"
722
+ """Enhancer to use for post-processing."""
723
+
724
+ init_from: str = "resume"
725
+ """Either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl')."""
726
+
727
+ use_kv_cache: Optional[Literal["flash_decoding", "vanilla"]] = get_default_use_kv_cache()
728
+ """Type of kv caching to use for inference: 1) [none] no kv caching, 2) [flash_decoding] use the
729
+ flash decoding kernel, 3) [vanilla] use torch attention with hand implemented kv-cache."""
730
+
731
+ output_dir: str = "samples/"
732
+ """Relative path to output directory"""
733
+
734
+ guidance_scale: Optional[Tuple[float, float]] = (3.0, 1.0)
735
+ """Guidance scale for sampling: (speaker conditioning guidance_scale, prompt conditioning guidance scale)."""
736
+
737
+ batch_size: int = 128
738
+ """Batch size to use for sampling. Note that the batch size gets doubled when guidance is used. For H100, and 1B model,
739
+ 1 w/ guidance and 1 w/o guidance work well (without kv-caching). With kv-caching, 128 (w/o guidance) and
740
+ 64 (w/ guidance) works well."""
741
+
742
+ output_name: str = "generated-custom.wav"
743
+
744
+ def metavoice_gen(sampling_config):
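+ # Full two-stage generation for a single SamplingControllerConfig: downloads the
+ # checkpoints from the Hugging Face Hub, builds the speaker encoder and both stage
+ # models, and synthesises sampling_config.text into output_dir/output_name.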
745
 
746
+ # sampling_config is passed in directly; it was originally parsed via tyro.cli(SamplingControllerConfig, use_underscores=True)
747
+
748
+ check_audio_file(sampling_config.spk_cond_path)
749
+
750
+ model_dir = snapshot_download(repo_id=sampling_config.huggingface_repo_id)
751
+ first_stage_ckpt_path = get_first_stage_path(model_dir)
752
+ second_stage_ckpt_path = get_second_stage_path(model_dir)
753
+
754
+ config_first_stage = InferenceConfig(
755
+ ckpt_path=first_stage_ckpt_path,
756
+ num_samples=sampling_config.num_samples,
757
+ seed=sampling_config.seed,
758
+ device=sampling_config.device,
759
+ dtype=sampling_config.dtype,
760
+ compile=sampling_config.compile,
761
+ init_from=sampling_config.init_from,
762
+ output_dir=sampling_config.output_dir,
763
+ )
764
+
765
+ config_second_stage = InferenceConfig(
766
+ ckpt_path=second_stage_ckpt_path,
767
+ num_samples=sampling_config.num_samples,
768
+ seed=sampling_config.seed,
769
+ device=sampling_config.device,
770
+ dtype=sampling_config.dtype,
771
+ compile=sampling_config.compile,
772
+ init_from=sampling_config.init_from,
773
+ output_dir=sampling_config.output_dir,
774
+ )
775
+
776
+ # sampling_config.max_new_tokens *= (
777
+ # 2 # deal with max_new_tokens for flattened interleaving! (should scale with num_codebooks?)
778
+ # )
779
+
780
+ # define models
781
+ smodel, llm_first_stage, llm_second_stage = build_models(
782
+ config_first_stage,
783
+ config_second_stage,
784
+ model_dir=model_dir,
785
+ device=sampling_config.device,
786
+ use_kv_cache=sampling_config.use_kv_cache,
787
+ )
788
+
789
+ sample_utterance(
790
+ sampling_config.text,
791
+ os.path.expanduser(sampling_config.spk_cond_path),
792
+ smodel,
793
+ llm_first_stage,
794
+ llm_second_stage,
795
+ sampling_config.enhancer,
796
+ first_stage_ckpt_path,
797
+ second_stage_ckpt_path,
798
+ sampling_config.guidance_scale,
799
+ max_new_tokens=sampling_config.max_new_tokens,
800
+ top_k=sampling_config.top_k,
801
+ top_p=sampling_config.top_p,
802
+ temperature=sampling_config.temperature,
803
+ output_name = sampling_config.output_name,
804
+ output_dir=sampling_config.output_dir,
805
+ )
806
+
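
For reference, a minimal driver sketch (illustrative only, not part of the commit) showing how the new entry point might be invoked. It assumes metaVoice.py is importable and that the paths hard-coded above already exist: a speaker reference at ./tmp/audio/input_src/0.wav and a loudness reference at ./tmp/audio/input_audio.wav used by the loudness-normalisation step.

# Hypothetical usage sketch; the file paths are assumptions taken from the code above.
from metaVoice import run_audio_generation_v2

if __name__ == "__main__":
    script = (
        "MetaVoice reads this NASA briefing aloud. "
        "All-caps abbreviations such as NASA are spaced out letter by letter before synthesis."
    )
    # Texts up to 220 characters are synthesised in one pass; longer texts are split on
    # sentence boundaries, generated chunk by chunk, and concatenated.
    run_audio_generation_v2(script)
    # The final audio is written to ./tmp/audio/generated-custom.wav.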