ChatTTS-Forge / modules /generate_audio.py
zhzluke96
update
29536f1
raw
history blame
3.07 kB
import numpy as np
import torch
from modules.speaker import Speaker
from modules.utils.SeedContext import SeedContext
from modules import models, config
import logging
logger = logging.getLogger(__name__)
@torch.inference_mode()
def generate_audio(
text: str,
temperature: float = 0.3,
top_P: float = 0.7,
top_K: float = 20,
spk: int | Speaker = -1,
infer_seed: int = -1,
use_decoder: bool = True,
prompt1: str = "",
prompt2: str = "",
prefix: str = "",
):
(sample_rate, wav) = generate_audio_batch(
[text],
temperature=temperature,
top_P=top_P,
top_K=top_K,
spk=spk,
infer_seed=infer_seed,
use_decoder=use_decoder,
prompt1=prompt1,
prompt2=prompt2,
prefix=prefix,
)[0]
return (sample_rate, wav)
@torch.inference_mode()
def generate_audio_batch(
texts: list[str],
temperature: float = 0.3,
top_P: float = 0.7,
top_K: float = 20,
spk: int | Speaker = -1,
infer_seed: int = -1,
use_decoder: bool = True,
prompt1: str = "",
prompt2: str = "",
prefix: str = "",
):
chat_tts = models.load_chat_tts()
params_infer_code = {
"spk_emb": None,
"temperature": temperature,
"top_P": top_P,
"top_K": top_K,
"prompt1": prompt1 or "",
"prompt2": prompt2 or "",
"prefix": prefix or "",
"repetition_penalty": 1.0,
"disable_tqdm": config.disable_tqdm,
}
if isinstance(spk, int):
with SeedContext(spk):
params_infer_code["spk_emb"] = chat_tts.sample_random_speaker()
logger.info(("spk", spk))
elif isinstance(spk, Speaker):
params_infer_code["spk_emb"] = spk.emb
logger.info(("spk", spk.name))
else:
raise ValueError("spk must be int or Speaker")
logger.info(
{
"text": texts,
"infer_seed": infer_seed,
"temperature": temperature,
"top_P": top_P,
"top_K": top_K,
"prompt1": prompt1 or "",
"prompt2": prompt2 or "",
"prefix": prefix or "",
}
)
with SeedContext(infer_seed):
wavs = chat_tts.generate_audio(
texts, params_infer_code, use_decoder=use_decoder
)
sample_rate = 24000
return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
if __name__ == "__main__":
import soundfile as sf
# 测试batch生成
inputs = ["你好[lbreak]", "再见[lbreak]", "长度不同的文本片段[lbreak]"]
outputs = generate_audio_batch(inputs, spk=5, infer_seed=42)
for i, (sample_rate, wav) in enumerate(outputs):
print(i, sample_rate, wav.shape)
sf.write(f"batch_{i}.wav", wav, sample_rate, format="wav")
# 单独生成
for i, text in enumerate(inputs):
sample_rate, wav = generate_audio(text, spk=5, infer_seed=42)
print(i, sample_rate, wav.shape)
sf.write(f"one_{i}.wav", wav, sample_rate, format="wav")