"""Vietnamese text-to-speech: text normalization, phone indexing and synthesis.

Loads the model config and phone inventory at import time. Exposes helpers that
turn Vietnamese text into phone indices, synthesize int16 PCM audio with a
duration model plus a generator, and write the result to a WAV file.
"""

import functools
import json
import re
import textwrap
import unicodedata
from types import SimpleNamespace

import numpy as np
import regex
import scipy.io.wavfile as wav
import torch

from LOAD.models import DurationNet, SynthesizerTrn

config_file = "config.json"
duration_model_path = "duration_model.pth"
lightspeed_model_path = "gen_630k.pth"
phone_set_file = "phone_set.json"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model hyper-parameters; nested JSON objects become attribute-style namespaces
# (e.g. hps.data.sampling_rate).
with open(config_file, "rb") as f:
    hps = json.load(f, object_hook=lambda x: SimpleNamespace(**x))

# Phone inventory. Decode explicitly as UTF-8: the entries are matched against
# Vietnamese characters, and the platform default encoding (e.g. cp1252 on
# Windows) would mis-decode them.
with open(phone_set_file, "r", encoding="utf-8") as f:
    phone_set = json.load(f)

# Explicit checks instead of `assert`, which is stripped under `python -O`.
# Index 0 must be the separator token (its name wrapped in one delimiter char
# on each side); it is emitted for spaces in text_to_phone_idx.
if phone_set[0][1:-1] != "SEP":
    raise ValueError(
        f"{phone_set_file}: first entry must be the SEP token, got {phone_set[0]!r}"
    )
if "sil" not in phone_set:
    raise ValueError(f"{phone_set_file}: missing 'sil' phone")
sil_idx = phone_set.index("sil")

# O(1) phone -> index lookup. Building it in reverse lets the FIRST occurrence
# of a duplicated entry win, matching the semantics of list.index().
_phone_to_idx = {p: i for i, p in reversed(list(enumerate(phone_set)))}

space_re = regex.compile(r"\s+")
number_re = regex.compile("([0-9]+)")
digits = ["không", "một", "hai", "ba", "bốn", "năm", "sáu", "bảy", "tám", "chín"]
num_re = regex.compile(r"([0-9.,]*[0-9])")
alphabet = "aàáảãạăằắẳẵặâầấẩẫậeèéẻẽẹêềếểễệiìíỉĩịoòóỏõọôồốổỗộơờớởỡợuùúủũụưừứửữựyỳýỷỹỵbcdđghklmnpqrstvx"
keep_text_and_num_re = regex.compile(rf"[^\s{alphabet}.,0-9]")
keep_text_re = regex.compile(rf"[^\s{alphabet}]")

# Punctuation that is padded with spaces and then mapped to the "sil" phone.
_SIL_PUNCT = ":,.!?;("
_PUNCT_PADDING = str.maketrans({p: f" {p} " for p in _SIL_PUNCT})


def read_number(num: str) -> str:
    """Spell out a numeric string in Vietnamese words.

    Handles plain digit runs of up to 9 digits. It also handles "," as a decimal
    separator ("phẩy") and "." as a thousands separator (2 or 3 groups).
    Anything else is returned unchanged.

    Examples: "21" -> "hai mươi mốt", "105" -> "một trăm lẻ năm",
    "1000" -> "một ngàn", "05" -> "không năm".
    """
    is_int = num.isdecimal()  # int() accepts exactly the Unicode-decimal chars

    if len(num) == 1 and is_int:
        return digits[int(num)]

    if len(num) == 2 and is_int:
        if num[0] == "0":
            # Leading zero ("05"): read digit by digit. The tens logic below
            # would otherwise turn "05" into "mười lăm" (15).
            return digits[0] + " " + digits[int(num[1])]
        n = int(num)
        tens, ones = divmod(n, 10)
        if n == 10:
            return "mười"
        if ones == 0:
            return digits[tens] + " mươi"
        end = "lăm" if ones == 5 else digits[ones]
        if n < 20:
            return "mười " + end
        if ones == 1:
            end = "mốt"
        return digits[tens] + " mươi " + end

    if len(num) == 3 and is_int:
        n = int(num)
        hundreds = digits[n // 100]
        if n % 100 == 0:
            return hundreds + " trăm"
        if num[1] == "0":
            # e.g. "105" -> "một trăm lẻ năm"
            return hundreds + " trăm lẻ " + digits[n % 100]
        return hundreds + " trăm " + read_number(num[1:])

    if 4 <= len(num) <= 6 and is_int:
        head = read_number(str(int(num) // 1000)) + " ngàn"
        tail = num[-3:]
        if int(tail) == 0:
            # "1000" -> "một ngàn", not "một ngàn không trăm".
            return head
        return head + " " + read_number(tail)

    if 7 <= len(num) <= 9 and is_int:
        head = read_number(str(int(num) // 10**6)) + " triệu"
        rest = num[-6:]
        if int(rest) == 0:
            return head
        if int(rest[:3]) == 0:
            # Skip an empty thousands group: "1000005" -> "một triệu không trăm lẻ năm".
            return head + " " + read_number(rest[3:])
        return head + " " + read_number(rest)

    if "," in num:
        # partition (not split) so inputs with several commas cannot crash
        # the unpacking.
        int_part, _, frac = num.partition(",")
        return read_number(int_part) + " phẩy " + read_number(frac)

    if "." in num:
        parts = num.split(".")
        if len(parts) == 2:
            whole, group = parts
            if group == "000":
                return read_number(whole) + " ngàn"
            if len(group) == 3 and group.startswith("00"):
                # Length check guards int("") / out-of-range digit lookups.
                return read_number(whole) + " ngàn lẻ " + digits[int(group[2])]
            return read_number(whole) + " ngàn " + read_number(group)
        if len(parts) == 3:
            millions, thousands, units = parts
            out = read_number(millions) + " triệu"
            # All-zero groups are omitted instead of read as "không trăm".
            if thousands.strip("0"):
                out += " " + read_number(thousands) + " ngàn"
            if units.strip("0"):
                out += " " + read_number(units)
            return out

    return num


def text_to_phone_idx(text):
    """Convert raw text into a list of phone indices for the duration model.

    Lower-cases and NFKC-normalizes the text. Numbers are spelled out via
    read_number. Punctuation in ":,.!?;(" maps to the "sil" phone and spaces
    map to index 0 (the SEP token). Any other character not in the phone set
    is dropped. The result is padded so it starts and ends with "sil".
    Text that yields no phones returns [sil_idx].
    """
    text = unicodedata.normalize("NFKC", text.lower())
    text = text.translate(_PUNCT_PADDING)
    text = num_re.sub(r" \1 ", text)
    # NOTE(review): because "." and "," are padded above BEFORE numbers are
    # matched, read_number only ever sees plain digit runs here. So "10.000"
    # is read as "mười <sil> không trăm" and read_number's separator branches
    # are unreachable from this path. TODO: decide how to treat dates/lists
    # vs grouped numbers before changing the order.
    words = [read_number(w) if num_re.fullmatch(w) else w for w in text.split()]
    text = re.sub(r"\s+", " ", " ".join(words)).strip()

    tokens = []
    for ch in text:
        if ch in _SIL_PUNCT:
            tokens.append(sil_idx)
        elif ch in _phone_to_idx:
            tokens.append(_phone_to_idx[ch])
        elif ch == " ":
            tokens.append(0)

    if not tokens:
        # Nothing pronounceable (e.g. emoji only); return the same sequence
        # the pipeline produces for a lone ".", instead of crashing below.
        return [sil_idx]
    if tokens[0] != sil_idx:
        tokens = [sil_idx, 0] + tokens
    if tokens[-1] != sil_idx:
        tokens = tokens + [0, sil_idx]
    return tokens


def text_to_speech(duration_net, generator, text):
    """Synthesize `text` and return int16 PCM samples.

    The sample rate is hps.data.sampling_rate. Predicted silence durations are
    clamped to at least 200 (ms), and SEP tokens are given zero duration.
    """
    phone_ids = text_to_phone_idx(text)
    phone_idx = torch.tensor([phone_ids], dtype=torch.long, device=device)
    phone_length = torch.tensor([len(phone_ids)], dtype=torch.long, device=device)

    with torch.inference_mode():
        # Duration-net output appears to be in seconds; scaled to milliseconds.
        phone_duration = duration_net(phone_idx, phone_length)[:, :, 0] * 1000
        phone_duration = torch.where(
            phone_idx == sil_idx, torch.clamp_min(phone_duration, 200), phone_duration
        )
        phone_duration = torch.where(phone_idx == 0, 0, phone_duration)

        # Convert per-phone [start, end) times (ms) into spectrogram frames and
        # build a hard monotonic alignment matrix (frames x phones).
        end_time = torch.cumsum(phone_duration, dim=-1)
        start_time = end_time - phone_duration
        start_frame = start_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
        end_frame = end_time / 1000 * hps.data.sampling_rate / hps.data.hop_length
        spec_length = end_frame.max(dim=-1).values
        pos = torch.arange(0, spec_length.item(), device=device)
        attn = torch.logical_and(
            pos[None, :, None] >= start_frame[:, None, :],
            pos[None, :, None] < end_frame[:, None, :],
        ).float()

        y_hat = generator.infer(
            phone_idx, phone_length, spec_length, attn, max_len=None, noise_scale=0.0
        )[0]

    wave = y_hat[0, 0].data.cpu().numpy()
    # Clip before the cast: a full-scale sample (1.0 -> 32768) would otherwise
    # wrap around to -32768 in int16.
    return np.clip(wave * (2**15), -(2**15), 2**15 - 1).astype(np.int16)


@functools.lru_cache(maxsize=1)
def load_models():
    """Load the duration model and generator once, in eval mode on `device`.

    Memoized: repeated calls return the same (duration_net, generator) pair
    instead of re-reading the checkpoints from disk.
    """
    duration_net = DurationNet(hps.data.vocab_size, 64, 4).to(device)
    # NOTE: torch.load unpickles its input; only load trusted checkpoints.
    duration_net.load_state_dict(torch.load(duration_model_path, map_location=device))
    duration_net = duration_net.eval()

    generator = SynthesizerTrn(
        hps.data.vocab_size,
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **vars(hps.model),
    ).to(device)
    # enc_q is not used by infer(); dropping it frees memory (hence strict=False below).
    del generator.enc_q

    ckpt = torch.load(lightspeed_model_path, map_location=device)
    prefix = "module."  # added when the checkpoint was saved from a DataParallel wrapper
    params = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in ckpt["net_g"].items()
    }
    generator.load_state_dict(params, strict=False)
    del ckpt, params
    generator = generator.eval()
    return duration_net, generator


def speak(text, max_chunk_length=400):
    """Synthesize multi-paragraph text.

    Returns (sampling_rate, int16 samples). Each non-blank line is synthesized
    separately, in chunks of at most `max_chunk_length` characters. Chunks are
    split on whitespace so words are never cut in half. Text with no content
    yields an empty sample array.
    """
    duration_net, generator = load_models()

    clips = []
    for paragraph in text.split("\n"):
        paragraph = paragraph.strip()
        if not paragraph:
            continue
        chunks = textwrap.wrap(
            paragraph,
            width=max_chunk_length,
            break_long_words=True,  # a single over-long token still gets split
            break_on_hyphens=False,
        )
        for chunk in chunks:
            clips.append(text_to_speech(duration_net, generator, chunk))

    if not clips:
        # np.concatenate([]) raises; return an empty clip instead.
        return hps.data.sampling_rate, np.zeros(0, dtype=np.int16)
    return hps.data.sampling_rate, np.concatenate(clips)


def textToMp3(text, outWAV):
    """Synthesize `text` and write it to `outWAV`.

    Despite the function name, the output is a WAV file (scipy.io.wavfile),
    not MP3.
    """
    sampling_rate, audio = speak(text)
    wav.write(outWAV, sampling_rate, audio)


if __name__ == "__main__":
    # Demo run only when executed as a script, not on import.
    textToMp3("bây giờ là mấy giờ", "test.wav")