emo-knob / fam /llm /fast_inference.py
tonychenxyz's picture
init
9e34a62
raw
history blame
5.58 kB
import os
import shutil
import tempfile
import time
from pathlib import Path
import librosa
import torch
from huggingface_hub import snapshot_download
from fam.llm.adapters import FlattenedInterleavedEncodec2Codebook
from fam.llm.decoders import EncodecDecoder
from fam.llm.fast_inference_utils import build_model, main
from fam.llm.inference import (
EncodecDecoder,
InferenceConfig,
Model,
TiltedEncodec,
TrainedBPETokeniser,
get_cached_embedding,
get_cached_file,
get_enhancer,
)
from fam.llm.utils import (
check_audio_file,
get_default_dtype,
get_device,
normalize_text,
)
class TTS:
    """MetaVoice two-stage text-to-speech pipeline with a torch.compile'd first stage.

    The first-stage LLM (built by ``build_model`` with ``compile=True``) turns normalized
    text plus a speaker embedding into audio tokens. The second-stage ``Model`` (using the
    ``TiltedEncodec`` adapter) turns those tokens into a waveform file. That file is then
    post-processed in place by the enhancer returned from ``get_enhancer("df")``.
    """

    # End-of-audio marker token id handed to both the first- and second-stage data adapters.
    # NOTE(review): presumably one past the last Encodec codebook id (0..1023) — confirm.
    END_OF_AUDIO_TOKEN = 1024

    def __init__(
        self,
        model_name: str = "metavoiceio/metavoice-1B-v0.1",
        *,
        seed: int = 1337,
        output_dir: str = "outputs",
        cache_dir: str = "/proj/afosr/metavoice/cache",
    ):
        """
        model_name (str): refers to the model identifier from the Hugging Face Model Hub (https://huggingface.co/metavoiceio)
        seed (int): sampling seed forwarded to the second-stage InferenceConfig.
        output_dir (str): directory the second stage writes audio into; created if missing.
        cache_dir (str): local cache directory passed to ``snapshot_download``. The default
            keeps the previously hard-coded path for backward compatibility; override it on
            machines where that path does not exist.
        """
        # NOTE: this needs to come first so that we don't change global state when we want to use
        # the torch.compiled-model.
        self._dtype = get_default_dtype()
        self._device = get_device()
        self._model_dir = snapshot_download(repo_id=model_name, cache_dir=cache_dir)
        self.first_stage_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

        # Second stage: regular (non-compiled) Model resumed from second_stage.pt.
        second_stage_ckpt_path = f"{self._model_dir}/second_stage.pt"
        config_second_stage = InferenceConfig(
            ckpt_path=second_stage_ckpt_path,
            num_samples=1,
            seed=seed,
            device=self._device,
            dtype=self._dtype,
            compile=False,
            init_from="resume",
            output_dir=self.output_dir,
        )
        data_adapter_second_stage = TiltedEncodec(end_of_audio_token=self.END_OF_AUDIO_TOKEN)
        self.llm_second_stage = Model(
            config_second_stage, TrainedBPETokeniser, EncodecDecoder, data_adapter_fn=data_adapter_second_stage.decode
        )
        self.enhancer = get_enhancer("df")

        # Only half-precision dtype names are mapped; any other value raises KeyError here.
        self.precision = {"float16": torch.float16, "bfloat16": torch.bfloat16}[self._dtype]
        # First stage: compiled model, tokenizer and speaker-encoder model.
        self.model, self.tokenizer, self.smodel, self.model_size = build_model(
            precision=self.precision,
            checkpoint_path=Path(f"{self._model_dir}/first_stage.pt"),
            spk_emb_ckpt_path=Path(f"{self._model_dir}/speaker_encoder.pt"),
            device=self._device,
            compile=True,
            compile_prefill=True,
        )

    def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
        """
        text: Text to speak
        spk_ref_path: Path to speaker reference file. Min. 30s of audio required. Supports both local paths & public URIs. Audio formats: wav, flac & mp3
        top_p: Top p for sampling applied to first-stage model. Range [0.9, 1.0] are good. This is a measure of speech stability - improves text following for a challenging speaker
        guidance_scale: Guidance scale [1.0, 3.0] for sampling. This is a measure of speaker similarity - how closely to match speaker identity and speech style.
        temperature: Temperature for sampling applied to the first-stage LLM only. The second
            stage always samples with temperature 1.0 and top_k=200.
        returns: path to speech .wav file
        """
        text = normalize_text(text)
        spk_ref_path = get_cached_file(spk_ref_path)
        check_audio_file(spk_ref_path)
        spk_emb = get_cached_embedding(
            spk_ref_path,
            self.smodel,
        ).to(device=self._device, dtype=self.precision)

        # perf_counter is monotonic, so it is the right clock for measuring an interval.
        start = time.perf_counter()

        # first stage LLM
        tokens = main(
            model=self.model,
            tokenizer=self.tokenizer,
            model_size=self.model_size,
            prompt=text,
            spk_emb=spk_emb,
            top_p=torch.tensor(top_p, device=self._device, dtype=self.precision),
            guidance_scale=torch.tensor(guidance_scale, device=self._device, dtype=self.precision),
            temperature=torch.tensor(temperature, device=self._device, dtype=self.precision),
        )
        _, extracted_audio_ids = self.first_stage_adapter.decode([tokens])
        b_speaker_embs = spk_emb.unsqueeze(0)

        # second stage LLM + multi-band diffusion model
        wav_files = self.llm_second_stage(
            texts=[text],
            encodec_tokens=[torch.tensor(extracted_audio_ids, dtype=torch.int32, device=self._device).unsqueeze(0)],
            speaker_embs=b_speaker_embs,
            batch_size=1,
            guidance_scale=None,
            top_p=None,
            top_k=200,
            temperature=1.0,
            max_new_tokens=None,
        )
        # The file on disk is the returned path with ".wav" appended.
        # NOTE(review): presumably the second stage returns extension-less stems — confirm.
        wav_path = str(wav_files[0]) + ".wav"

        # enhance using deepfilternet. The enhancer writes to a path inside a temporary
        # directory rather than to the name of an open NamedTemporaryFile. On Windows an
        # open NamedTemporaryFile cannot be reopened by name, so that write would fail.
        with tempfile.TemporaryDirectory() as tmp_dir:
            enhanced_tmp = os.path.join(tmp_dir, "enhanced.wav")
            self.enhancer(wav_path, enhanced_tmp)
            shutil.copy2(enhanced_tmp, wav_path)
        print(f"\nSaved audio to {wav_path}")

        # calculating real-time factor (RTF)
        time_to_synth_s = time.perf_counter() - start
        audio, sr = librosa.load(wav_path)
        duration_s = librosa.get_duration(y=audio, sr=sr)
        print(f"\nTotal time to synth (s): {time_to_synth_s}")
        # Guard against an empty clip. The audio is already saved, so a ZeroDivisionError
        # here would only discard a valid result.
        if duration_s > 0:
            print(f"Real-time factor: {time_to_synth_s / duration_s:.2f}")
        else:
            print("Real-time factor: n/a (generated audio is empty)")
        return wav_path
if __name__ == "__main__":
    # Script entry point. Building the pipeline downloads the checkpoints and compiles the
    # first-stage model; no synthesis is performed here.
    pipeline = TTS()