import os
import re

import torch
import torchaudio
from einops import rearrange
from vocos import Vocos

from model import CFM, UNetT, DiT, MMDiT
from model.utils import (
    load_checkpoint,
    get_tokenizer,
    convert_char_to_pinyin,
    save_spectrogram,
)

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
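# Mel-spectrogram and tokenizer settings (these should match the config the checkpoint was trained with)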
target_sample_rate = 24000
n_mel_channels = 100
hop_length = 256
target_rms = 0.1

tokenizer = "pinyin"
dataset_name = "Emilia_ZH_EN"
|
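# Inference settings (ODE sampling, guidance, and duration control)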
seed = None  # int | None

exp_name = "F5TTS_Base"  # "F5TTS_Base" | "E2TTS_Base"
ckpt_step = 1200000

nfe_step = 32  # number of function evaluations (ODE solver steps)
cfg_strength = 2.  # classifier-free guidance strength
ode_method = 'euler'
sway_sampling_coef = -1.
speed = 0.8  # only used when fix_duration is None
fix_duration = 27  # total duration in seconds (reference + generated), or None to estimate from text length
|
if exp_name == "F5TTS_Base":
    model_cls = DiT
    model_cfg = dict(dim = 1024, depth = 22, heads = 16, ff_mult = 2, text_dim = 512, conv_layers = 4)

elif exp_name == "E2TTS_Base":
    model_cls = UNetT
    model_cfg = dict(dim = 1024, depth = 24, heads = 16, ff_mult = 4)

ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"
output_dir = "tests"

ref_audio = "tests/ref_audio/rashmika_input.wav"
ref_text = ""  # transcript of ref_audio; leave empty only when fix_duration is set, otherwise the length-based duration estimate below divides by zero
|
gen_text = "Happy Birthday, Dhillip Kumar. Virat Kohli this side, all the best for your future endeavours! "
|
use_ema = True  # load the EMA (exponential moving average) weights from the checkpoint

if not os.path.exists(output_dir):
    os.makedirs(output_dir)
|
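# Vocoder: Vocos (mel -> waveform), loaded from a local checkpoint or pulled from the Hugging Face hub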
local = False
if local:
    vocos_local_path = "../checkpoints/charactr/vocos-mel-24khz"
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
    vocos.load_state_dict(state_dict)
    vocos.eval()
else:
    vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
|
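# Vocabulary (character/pinyin map) for the chosen dataset and tokenizer type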
vocab_char_map, vocab_size = get_tokenizer(dataset_name, tokenizer)
|
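# Flow-matching model (CFM) wrapping the selected transformer backbone (DiT or UNetT)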
model = CFM(
    transformer = model_cls(
        **model_cfg,
        text_num_embeds = vocab_size,
        mel_dim = n_mel_channels
    ),
    mel_spec_kwargs = dict(
        target_sample_rate = target_sample_rate,
        n_mel_channels = n_mel_channels,
        hop_length = hop_length,
    ),
    odeint_kwargs = dict(
        method = ode_method,
    ),
    vocab_char_map = vocab_char_map,
).to(device)

model = load_checkpoint(model, ckpt_path, device, use_ema = use_ema)
|
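# Reference audio: downmix to mono, normalize quiet clips toward target_rms, resample to the target sample rate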
audio, sr = torchaudio.load(ref_audio)
if audio.shape[0] > 1:
    audio = torch.mean(audio, dim=0, keepdim=True)
rms = torch.sqrt(torch.mean(torch.square(audio)))
if rms < target_rms:
    audio = audio * target_rms / rms
if sr != target_sample_rate:
    resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
    audio = resampler(audio)
audio = audio.to(device)
|
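# Build the text prompt: reference transcript followed by the text to generate,
# converted to pinyin when using the pinyin tokenizer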
text_list = [ref_text + gen_text]
if tokenizer == "pinyin":
    final_text_list = convert_char_to_pinyin(text_list)
else:
    final_text_list = [text_list]
print(f"text  : {text_list}")
print(f"pinyin: {final_text_list}")
|
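# Total output duration in mel frames: either fixed, or a linear estimate from the
# reference/generated text lengths scaled by the reference audio length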
ref_audio_len = audio.shape[-1] // hop_length  # reference length in mel frames
if fix_duration is not None:
    duration = int(fix_duration * target_sample_rate / hop_length)
else:
    # simple linear estimate from character counts; Chinese pause punctuation is counted twice
    zh_pause_punc = r"[。，、；：？！]"
    ref_text_len = len(ref_text) + len(re.findall(zh_pause_punc, ref_text))
    gen_text_len = len(gen_text) + len(re.findall(zh_pause_punc, gen_text))
    duration = ref_audio_len + int(ref_audio_len / ref_text_len * gen_text_len / speed)
|
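# Sample a mel-spectrogram conditioned on the reference audio and the combined text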
with torch.inference_mode():
    generated, trajectory = model.sample(
        cond = audio,
        text = final_text_list,
        duration = duration,
        steps = nfe_step,
        cfg_strength = cfg_strength,
        sway_sampling_coef = sway_sampling_coef,
        seed = seed,
    )
print(f"Generated mel: {generated.shape}")
|
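# Post-process: drop the reference frames, vocode to a waveform, and save the results to output_dir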
generated = generated[:, ref_audio_len:, :]
generated_mel_spec = rearrange(generated, '1 n d -> 1 d n')
generated_wave = vocos.decode(generated_mel_spec.cpu())
if rms < target_rms:
    generated_wave = generated_wave * rms / target_rms  # undo the RMS normalization applied to the reference

save_spectrogram(generated_mel_spec[0].cpu().numpy(), f"{output_dir}/test_single_dbday.png")
torchaudio.save(f"{output_dir}/test_single_dbday.wav", generated_wave, target_sample_rate)
print(f"Generated wav: {generated_wave.shape}")
|
|