# NOTE(review): removed non-Python web-scrape artifacts that preceded this
# file ("Spaces:" / "Runtime error" banners) — they would break parsing.
# text to semantic
import argparse
import os
import re
import time
from pathlib import Path

import librosa
import numpy as np
import torch
import whisper

from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.text_processing.phonemizer import GruutPhonemizer
from AR.utils.io import load_yaml_config
# NOTE(review): `SemanticTokenizer` is referenced in main() but never imported
# here — restore its import (module path not visible from this file).
def get_batch(text, phonemizer):
    """Build a single-sample batch of phoneme ids for the T2S model.

    Args:
        text: sentence to phonemize.
        phonemizer: object exposing ``phonemize(text, espeak=...)`` and
            ``transform(phoneme)`` (e.g. GruutPhonemizer).

    Returns:
        dict with:
            "phoneme_ids": torch.Tensor of shape (1, T) — batch axis added.
            "phoneme_ids_len": torch.Tensor of shape (1,).
    """
    phoneme = phonemizer.phonemize(text, espeak=False)
    ids = phonemizer.transform(phoneme)
    n_ids = len(ids)
    print("phoneme:", phoneme)
    # (T,) -> (1, T): the model consumes batched input
    ids_tensor = torch.tensor(np.array(ids)).unsqueeze(0)
    len_tensor = torch.tensor([n_ids])
    return {
        # torch.Tensor (B, max_phoneme_length)
        "phoneme_ids": ids_tensor,
        # torch.Tensor (B)
        "phoneme_ids_len": len_tensor,
    }
def get_prompt(prompt_wav_path, asr_model, phonemizer, semantic_tokenizer):
    """Extract prompt phonemes and prompt semantic tokens from a wav file.

    The wav tail is transcribed with Whisper, phonemized, and tokenized into
    semantic ids that are later prepended to the model input as the timbre
    prompt.

    Args:
        prompt_wav_path: path of the reference wav.
        asr_model: Whisper model providing ``transcribe``.
        phonemizer: object with ``phonemize`` / ``transform``.
        semantic_tokenizer: object with ``tokenize`` mapping (1, T) audio on
            GPU to semantic token ids.

    Returns:
        dict with 'prompt_name', 'prompt_phoneme_ids' (1, T),
        'prompt_semantic_tokens' (1, T') int32, and
        'prompt_phoneme_ids_len' (1,).
    """
    sr = 16000
    prompt_name = os.path.basename(prompt_wav_path).split('.')[0]
    audio, _ = librosa.load(prompt_wav_path, sr=sr)
    # Keep the last 3 s but drop the final 0.1 s: trailing audio/silence can
    # make AR S1 inference stop too early.
    audio = audio[-sr * 3:-int(sr * 0.1)]
    # Trailing silence should also be trimmed or inference may stop early.
    transcript = asr_model.transcribe(audio)["text"]
    # Strip periods: a final period can cause a pause / premature stop.
    transcript = transcript.replace(".", "")
    phoneme = phonemizer.phonemize(transcript, espeak=False)
    phoneme_ids = phonemizer.transform(phoneme)
    n_phoneme_ids = len(phoneme_ids)
    # get prompt semantic tokens: (T,) -> (1, T), moved to GPU
    audio_tensor = torch.tensor(audio).unsqueeze(0).cuda()
    # (1, T') int32 token ids
    semantic_tokens = semantic_tokenizer.tokenize(audio_tensor).to(torch.int32)
    return {
        'prompt_name': prompt_name,
        'prompt_phoneme_ids': torch.tensor(phoneme_ids).unsqueeze(0),
        'prompt_semantic_tokens': semantic_tokens,
        'prompt_phoneme_ids_len': torch.tensor([n_phoneme_ids]),
    }
def parse_args():
    """Parse command-line arguments for the AR S1 text-to-semantic runner.

    Returns:
        argparse.Namespace with config_file, text_file, ckpt_path,
        prompt_wav_path, hubert_path, quantizer_path, and output_dir.
    """
    parser = argparse.ArgumentParser(
        description="Run SoundStorm AR S1 model for input text file")
    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')
    parser.add_argument(
        "--text_file",
        type=str,
        help="text file to be convert to semantic tokens, a 'utt_id sentence' pair per line."
    )
    parser.add_argument(
        '--ckpt_path',
        type=str,
        default='exp/default/ckpt/epoch=99-step=49000.ckpt',
        help='Checkpoint file of SoundStorm AR S1 model.')
    parser.add_argument(
        '--prompt_wav_path',
        type=str,
        default=None,
        help='extract prompt semantic and prompt phonemes from prompt wav')
    # prompt-wav semantic extraction models
    parser.add_argument("--hubert_path", type=str, default=None)
    parser.add_argument("--quantizer_path", type=str, default=None)
    parser.add_argument("--output_dir", type=str, help="output dir.")
    return parser.parse_args()
def main():
    """Convert each sentence in --text_file to semantic tokens.

    For every (utt_id, sentence) pair, the phonemes of a reference prompt wav
    are prepended, the AR S1 model predicts semantic tokens, the prompt part
    is stripped off, and one two-row TSV ('item_name', 'semantic_audio') is
    written per utterance into --output_dir.
    """
    args = parse_args()
    config = load_yaml_config(args.config_file)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    hz = 50  # semantic token rate used to bound generation length
    max_sec = config['data']['max_sec']

    # get models
    t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
        checkpoint_path=args.ckpt_path, config=config)
    t2s_model.cuda()
    t2s_model.eval()

    phonemizer: GruutPhonemizer = GruutPhonemizer(language='en-us')

    # models used only to build the prompt
    asr_model = whisper.load_model("tiny.en")
    # FIXME(review): `SemanticTokenizer` is never imported in this file, so
    # this line raises NameError at runtime — restore its import at the top.
    semantic_tokenizer = SemanticTokenizer(
        hubert_path=args.hubert_path,
        quantizer_path=args.quantizer_path,
        duplicate=True)

    prompt_result = get_prompt(
        prompt_wav_path=args.prompt_wav_path,
        asr_model=asr_model,
        phonemizer=phonemizer,
        semantic_tokenizer=semantic_tokenizer)
    # A zero prompt yields correct content but scrambled timbre:
    # prompt = torch.ones(
    #     batch['phoneme_ids'].size(0), 1, dtype=torch.int32) * 0
    prompt = prompt_result['prompt_semantic_tokens']
    prompt_phoneme_ids = prompt_result['prompt_phoneme_ids']
    prompt_phoneme_ids_len = prompt_result['prompt_phoneme_ids_len']

    sentences = _read_sentences(args.text_file)
    # NOTE(review): the original code skipped sentences[0] — presumably a
    # header line in the input file; confirm the format before changing this.
    for utt_id, sentence in sentences[1:]:
        # build a pseudo-batch of size 1 for the model
        batch = get_batch(sentence, phonemizer)
        # prepend the prompt phonemes to the target phonemes
        all_phoneme_ids = torch.cat(
            [prompt_phoneme_ids, batch['phoneme_ids']], dim=1)
        # equivalently: all_phoneme_ids.shape[-1]
        all_phoneme_len = prompt_phoneme_ids_len + batch['phoneme_ids_len']

        st = time.time()
        with torch.no_grad():
            pred_semantic = t2s_model.model.infer(
                all_phoneme_ids.cuda(),
                all_phoneme_len.cuda(),
                prompt.cuda(),
                top_k=config['inference']['top_k'],
                early_stop_num=hz * max_sec)
        print(f'{time.time() - st} sec used in T2S')

        # drop the prompt part of the prediction; batch size is 1
        pred_semantic = pred_semantic[:, prompt.shape[-1]:][0]
        token_str = ' '.join(
            str(x) for x in pred_semantic.detach().cpu().numpy().tolist())

        out_path = output_dir / (
            f'{utt_id}_p_{prompt_result["prompt_name"]}_semantic_token.tsv')
        _write_semantic_tsv(out_path, utt_id, token_str)


def _read_sentences(text_file):
    """Read 'utt_id sentence' pairs, one per non-blank line."""
    sentences = []
    with open(text_file, 'rt', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                # keyword maxsplit: the positional form of re.split's
                # maxsplit is deprecated since Python 3.13
                items = re.split(r"\s+", line, maxsplit=1)
                sentences.append((items[0], " ".join(items[1:])))
    return sentences


def _write_semantic_tsv(path, utt_id, token_str):
    """Write a two-line TSV: header row, then the (utt_id, tokens) row."""
    rows = [['item_name', 'semantic_audio'], [utt_id, token_str]]
    with open(path, 'w', encoding='utf-8') as writer:
        for row in rows:
            writer.write('\t'.join(row) + '\n')


if __name__ == "__main__":
    main()