# CapSpeech-TTS/capspeech/eval/src/example/dialect_world_dialect.py
import torch
import sys, os, pdb
import argparse, logging
import torch.nn.functional as F
from pathlib import Path
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'dialect'))
from wavlm_dialect import WavLMWrapper
from whisper_dialect import WhisperWrapper
# Configure console logging output
logging.basicConfig(
    format='%(asctime)s %(levelname)-3s ==> %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
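# Restrict MKL / NumExpr / OpenMP to a single thread each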
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
if __name__ == '__main__':
    label_list = [
        'East Asia', 'English', 'Germanic', 'Irish',
        'North America', 'Northern Irish', 'Oceania',
        'Other', 'Romance', 'Scottish', 'Semitic', 'Slavic',
        'South African', 'Southeast Asia', 'South Asia', 'Welsh'
    ]
    # Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        logging.info('GPU available, using GPU')
    # Define the models
    # Note: the ensemble yields better performance than either single model
    model_path = "YOUR_PATH"
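    # model_path is assumed to point to a directory containing the four
    # checkpoint files loaded below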
    # Define the model wrappers
    wavlm_model = WavLMWrapper(
        pretrain_model="wavlm_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=16,
        freeze_params=False,
        use_conv_output=True,
        apply_gradient_reversal=False,
        num_dataset=3
    ).to(device)
    whisper_model = WhisperWrapper(
        pretrain_model="whisper_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=16,
        freeze_params=False,
        use_conv_output=True,
        apply_gradient_reversal=False,
        num_dataset=11
    ).to(device)
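    # Both wrappers use rank-16 LoRA adapters and predict the 16 dialect
    # classes listed in label_list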
    # Load the base and LoRA checkpoints for each model
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, "wavlm_world_dialect.pt"), weights_only=True), strict=False)
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, "wavlm_world_dialect_lora.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_world_dialect.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_world_dialect_lora.pt"), weights_only=True), strict=False)
    wavlm_model.eval()
    whisper_model.eval()
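    # A minimal sketch of loading a real 16 kHz mono waveform in place of the
    # dummy zeros below (assumes torchaudio is installed and "sample.wav" is a
    # hypothetical input file):
    #   import torchaudio
    #   waveform, sr = torchaudio.load("sample.wav")           # [channels, samples]
    #   waveform = waveform.mean(dim=0, keepdim=True)          # downmix to mono
    #   if sr != 16000:
    #       waveform = torchaudio.functional.resample(waveform, sr, 16000)
    #   data = waveform.to(device)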
    # Dummy input: one second of silence at 16 kHz (replace with a real waveform)
    data = torch.zeros([1, 16000]).to(device)
    # Run both models, average their logits, and take the most likely dialect
    with torch.no_grad():
        wavlm_logits, wavlm_embeddings = wavlm_model(data, return_feature=True)
        whisper_logits, whisper_embeddings = whisper_model(data, return_feature=True)
    ensemble_logits = (wavlm_logits + whisper_logits) / 2
    ensemble_prob = F.softmax(ensemble_logits, dim=1)
    pred = label_list[ensemble_prob.argmax(dim=-1).item()]
    print(pred)
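    # Optional sketch: print the probability assigned to each dialect class
    # (uses only the variables defined above):
    #   for name, p in zip(label_list, ensemble_prob.squeeze(0).tolist()):
    #       print(f"{name}: {p:.3f}")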