import os

# Limit math-library threading; these must be set before torch is imported to take effect
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

import sys
import logging
from pathlib import Path

import torch
import torch.nn.functional as F

sys.path.append(str(Path(os.path.realpath(__file__)).parents[1]))
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'dialect'))
from wavlm_dialect import WavLMWrapper
from whisper_dialect import WhisperWrapper

# Configure console logging
logging.basicConfig(
    format='%(asctime)s %(levelname)-3s ==> %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
if __name__ == '__main__':
    # The 16 world dialect/accent groups predicted by the classifier
    label_list = [
        'East Asia', 'English', 'Germanic', 'Irish',
        'North America', 'Northern Irish', 'Oceania',
        'Other', 'Romance', 'Scottish', 'Semitic', 'Slavic',
        'South African', 'Southeast Asia', 'South Asia', 'Welsh'
    ]

    # Find the compute device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        logging.info('GPU available, using GPU')
    # Define the models
    # Note that the ensemble yields better performance than either single model
    model_path = "YOUR_PATH"

    # Define the model wrappers
    wavlm_model = WavLMWrapper(
        pretrain_model="wavlm_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=16,
        freeze_params=False,
        use_conv_output=True,
        apply_gradient_reversal=False,
        num_dataset=3
    ).to(device)

    whisper_model = WhisperWrapper(
        pretrain_model="whisper_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=16,
        freeze_params=False,
        use_conv_output=True,
        apply_gradient_reversal=False,
        num_dataset=11
    ).to(device)
    # Load the base classifier weights, then the LoRA adapter weights, into each model
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, "wavlm_world_dialect.pt"), weights_only=True), strict=False)
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, "wavlm_world_dialect_lora.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_world_dialect.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_world_dialect_lora.pt"), weights_only=True), strict=False)
    wavlm_model.eval()
    whisper_model.eval()
    # Dummy input: one second of silence at 16 kHz; replace with real audio
    data = torch.zeros([1, 16000]).to(device)

    # Average the two models' logits for the ensemble prediction
    with torch.no_grad():
        wavlm_logits, wavlm_embeddings = wavlm_model(data, return_feature=True)
        whisper_logits, whisper_embeddings = whisper_model(data, return_feature=True)
        ensemble_logits = (wavlm_logits + whisper_logits) / 2
        ensemble_prob = F.softmax(ensemble_logits, dim=1)
    pred = label_list[ensemble_prob.argmax(dim=-1).item()]
    print(pred)
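
    # Usage sketch (an assumption, not part of the original pipeline): to run on a
    # real recording instead of the zero tensor above, load the file with torchaudio,
    # downmix to mono, and resample to 16 kHz before passing it to the models.
    # "sample.wav" is a placeholder path.
    #
    # import torchaudio
    # waveform, sr = torchaudio.load("sample.wav")        # [channels, num_samples]
    # waveform = waveform.mean(dim=0, keepdim=True)       # downmix to mono
    # if sr != 16000:
    #     waveform = torchaudio.functional.resample(waveform, sr, 16000)
    # data = waveform.to(device)
    #
    # The ensemble probabilities can also be inspected beyond the top-1 label, e.g.:
    # top_prob, top_idx = ensemble_prob.squeeze(0).topk(3)
    # for p, i in zip(top_prob.tolist(), top_idx.tolist()):
    #     print(f"{label_list[i]}: {p:.3f}")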