# test from dump file
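# Example invocation (script name assumed; paths mirror the argparse defaults
# below, and output_dir is arbitrary):
#   python test.py --config_file conf/default.yaml \
#       --test_semantic_path dump/test/semantic_token.tsv \
#       --test_phoneme_path dump/test/phonemes.npy \
#       --ckpt_path 'exp/default/ckpt/epoch=99-step=49000.ckpt' \
#       --output_dir exp/default/test_output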
import argparse
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from AR.data.dataset import Text2SemanticDataset
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils.io import load_yaml_config
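
# Note: the AR.* imports resolve against this repository's package layout,
# so run this script from the project root (or add it to PYTHONPATH).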


def parse_args():
    # parse args and config
    parser = argparse.ArgumentParser(
        description="Run SoundStorm AR S1 model for test set.")

    parser.add_argument(
        '--config_file',
        type=str,
        default='conf/default.yaml',
        help='path of config file')

    # args for dataset
    parser.add_argument(
        '--test_semantic_path',
        type=str,
        default='dump/test/semantic_token.tsv')
    parser.add_argument(
        '--test_phoneme_path', type=str, default='dump/test/phonemes.npy')

    parser.add_argument(
        '--ckpt_path',
        type=str,
        default='exp/default/ckpt/epoch=99-step=49000.ckpt',
        help='Checkpoint file of SoundStorm AR S1 model.')
parser.add_argument("--output_dir", type=str, help="output dir.")
args = parser.parse_args()
return args


def main():
    args = parse_args()

    config = load_yaml_config(args.config_file)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    batch_size = 1
    hz = 50
    max_sec = config['data']['max_sec']
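    # hz appears to be the semantic-token rate used at training time;
    # hz * max_sec is passed as early_stop_num below to bound how many
    # tokens infer() may generate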

    # get dataset
    test_dataset = Text2SemanticDataset(
        phoneme_path=args.test_phoneme_path,
        semantic_path=args.test_semantic_path,
        # max_sec should match the value used at training time, otherwise
        # quality may suffer (e.g. repeated or dropped characters). Setting
        # it too small here would filter out long samples, so use a large
        # value and truncate at inference time instead.
        max_sec=100,
        max_sample=8,
        pad_val=config['data']['pad_val'])
    # get model
    t2s_model = Text2SemanticLightningModule.load_from_checkpoint(
        checkpoint_path=args.ckpt_path, config=config)
    t2s_model.cuda()
    t2s_model.eval()
    # fetch batch_size items at a time
    # create the DataLoader with the dataset's own collate_fn
    dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=test_dataset.collate)

    item_names = test_dataset.__get_item_names__()

    # read the data batch by batch; with bs=1 and shuffle=False the batch
    # index lines up with __get_item_names__()
    semantic_data = [['item_name', 'semantic_audio']]
    for i, batch in enumerate(dataloader):
        # this assumes bs = 1
        utt_id = item_names[i]
        if i == 0:
            print("utt_id:", utt_id)
            # with bs > 1 the batch would be zero-padded
            # keep this consistent with validation_step()
            semantic_len = batch['semantic_ids'].size(1)
            # use up to the first 150 tokens of batch['semantic_ids'] as the
            # prompt; across repeated syntheses the first prompt_len tokens
            # are identical and match the prompt
            prompt_len = min(int(semantic_len * 0.5), 150)
            # what should the prompt be for pure-text input? => see t2s.py
            prompt = batch['semantic_ids'][:, :prompt_len]
            # # a zero prompt also yields semantic tokens with the correct
            # # text content, but the timbre is scrambled, which shows the
            # # semantic tokens still carry timbre information
            # prompt = torch.ones(
            #     batch['semantic_ids'].size(0), 1, dtype=torch.int32) * 0
            # print("prompt:", prompt)
            # print("prompt.shape:", prompt.shape)
            np.save(output_dir / 'prompt.npy', prompt.detach().cpu().numpy())
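            # prompt.npy records the exact prompt used; per the note above,
            # the first prompt_len tokens of the prediction should match it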

            st = time.time()
            with torch.no_grad():
                # calculate acc for test
                loss, acc = t2s_model.model.forward(
                    batch['phoneme_ids'].cuda(),
                    batch['phoneme_ids_len'].cuda(),
                    batch['semantic_ids'].cuda(),
                    batch['semantic_ids_len'].cuda())
                print("top_3_acc of this batch:", acc)
                pred_semantic = t2s_model.model.infer(
                    batch['phoneme_ids'].cuda(),
                    batch['phoneme_ids_len'].cuda(),
                    prompt.cuda(),
                    top_k=config['inference']['top_k'],
                    # hz * max_sec in train dataloader
                    # a generated length of 1002 suggests some padding
                    early_stop_num=hz * max_sec)
            # bs = 1
            pred_semantic = pred_semantic[0]
            print(f'{time.time() - st} sec used in T2S')
            semantic_token = pred_semantic.detach().cpu().numpy().tolist()
            semantic_token_str = ' '.join(str(x) for x in semantic_token)
            semantic_data.append([utt_id, semantic_token_str])
        else:
            break

    # write semantic_token.tsv: a header row, then one
    # "<item_name>\t<space-separated token ids>" row per utterance
    delimiter = '\t'
    filename = output_dir / "semantic_token.tsv"
    with open(filename, 'w', encoding='utf-8') as writer:
        for row in semantic_data:
            line = delimiter.join(row)
            writer.write(line + '\n')


if __name__ == "__main__":
    main()