import torch
import logging
import sys, os
import torch.nn.functional as F
from pathlib import Path

sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'emotion'))

from wavlm_emotion import WavLMWrapper
from whisper_emotion import WhisperWrapper
# Configure console logging
logging.basicConfig(
    format='%(asctime)s %(levelname)-3s ==> %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
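# Limit MKL / NumExpr / OpenMP to a single thread to avoid CPU oversubscription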
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
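
# A minimal, hedged sketch (not part of the original script): the models expect
# 16 kHz mono audio, so a real clip could be loaded and resampled as shown here.
# torchaudio and the load_audio_16k name are assumptions introduced purely for
# illustration.
def load_audio_16k(path, device):
    import torchaudio
    waveform, sample_rate = torchaudio.load(path)      # shape: [channels, samples]
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)
    return waveform.to(device)                         # shape: [1, samples] at 16 kHz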
if __name__ == '__main__':
    label_list = [
        'Anger',
        'Contempt',
        'Disgust',
        'Fear',
        'Happiness',
        'Neutral',
        'Sadness',
        'Surprise',
        'Other'
    ]
    # Find device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print('GPU available, using GPU')
    # Define the model wrappers
    # Note that the ensemble yields better performance than either single model
    model_path = "model"
    wavlm_model = WavLMWrapper(
        pretrain_model="wavlm_large",
        finetune_method="finetune",
        output_class_num=9,
        freeze_params=True,
        use_conv_output=True,
        detailed_class_num=17
    ).to(device)
    whisper_model = WhisperWrapper(
        pretrain_model="whisper_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=9,
        freeze_params=True,
        use_conv_output=True,
        detailed_class_num=17
    ).to(device)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_emotion.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_emotion_lora.pt"), weights_only=True), strict=False)
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, "wavlm_emotion.pt"), weights_only=True), strict=False)
    wavlm_model.eval()
    whisper_model.eval()
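    # Both models are now in inference mode: dropout is disabled and
    # normalization layers use their running statistics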
    # Audio must be sampled at 16 kHz; here a 1-second tensor of silence serves as
    # a dummy input (a real clip could be prepared with the load_audio_16k sketch above)
    data = torch.zeros([1, 16000]).to(device)
    with torch.no_grad():
        whisper_logits, whisper_embedding, _, _, _, _ = whisper_model(
            data, return_feature=True
        )
        wavlm_logits, wavlm_embedding, _, _, _, _ = wavlm_model(
            data, return_feature=True
        )
    ensemble_logits = (whisper_logits + wavlm_logits) / 2
    ensemble_prob = F.softmax(ensemble_logits, dim=1)
    print(ensemble_prob.shape)
    print(whisper_embedding.shape)
    print(wavlm_embedding.shape)
    print(label_list[torch.argmax(ensemble_prob).detach().cpu().item()])
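
    # A small, hedged extension (not in the original): log the full per-class
    # distribution alongside the predicted label, assuming a batch size of 1
    for label, prob in zip(label_list, ensemble_prob[0].detach().cpu().tolist()):
        logging.info(f"{label}: {prob:.3f}")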