import torch
import logging
import sys, os
import torch.nn.functional as F
from pathlib import Path

# Make the repo root and the emotion model package importable
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'emotion'))

from wavlm_emotion import WavLMWrapper
from whisper_emotion import WhisperWrapper

# Define logging console
logging.basicConfig(
    format='%(asctime)s %(levelname)-3s ==> %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Limit CPU thread usage
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

if __name__ == '__main__':
    # Emotion classes predicted by the 9-way output head
    label_list = [
        'Anger', 'Contempt', 'Disgust', 'Fear',
        'Happiness', 'Neutral', 'Sadness', 'Surprise', 'Other'
    ]

    # Find device
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        print('GPU available, using GPU')

    # Define the model wrappers
    # Note that the ensemble yields better performance than a single model
    model_path = "model"
    wavlm_model = WavLMWrapper(
        pretrain_model="wavlm_large",
        finetune_method="finetune",
        output_class_num=9,
        freeze_params=True,
        use_conv_output=True,
        detailed_class_num=17
    ).to(device)

    whisper_model = WhisperWrapper(
        pretrain_model="whisper_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=9,
        freeze_params=True,
        use_conv_output=True,
        detailed_class_num=17
    ).to(device)

    # Load the fine-tuned weights (base weights plus LoRA adapter for Whisper)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_emotion.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_emotion_lora.pt"), weights_only=True), strict=False)
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, "wavlm_emotion.pt"), weights_only=True), strict=False)

    wavlm_model.eval()
    whisper_model.eval()

    # Audio must be sampled at 16 kHz; a 1-second dummy waveform is used here
    data = torch.zeros([1, 16000]).to(device)
    with torch.no_grad():
        whisper_logits, whisper_embedding, _, _, _, _ = whisper_model(
            data, return_feature=True
        )
        wavlm_logits, wavlm_embedding, _, _, _, _ = wavlm_model(
            data, return_feature=True
        )

    # Average the two models' logits and convert to class probabilities
    ensemble_logits = (whisper_logits + wavlm_logits) / 2
    ensemble_prob = F.softmax(ensemble_logits, dim=1)

    print(ensemble_prob.shape)
    print(whisper_embedding.shape)
    print(wavlm_embedding.shape)
    print(label_list[torch.argmax(ensemble_prob).detach().cpu().item()])
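
    # Minimal sketch (not part of the original script): running the same inference on a
    # real recording instead of the zero-valued dummy input. Assumes torchaudio is
    # installed and that a local file "audio.wav" exists; the path and variable names
    # are illustrative only.
    #
    # import torchaudio
    # waveform, sample_rate = torchaudio.load("audio.wav")        # shape: [channels, samples]
    # waveform = waveform.mean(dim=0, keepdim=True)               # downmix to mono -> [1, samples]
    # if sample_rate != 16000:                                     # models expect 16 kHz input
    #     waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)
    # data = waveform.to(device)
    # with torch.no_grad():
    #     whisper_logits, whisper_embedding, _, _, _, _ = whisper_model(data, return_feature=True)
    #     wavlm_logits, wavlm_embedding, _, _, _, _ = wavlm_model(data, return_feature=True)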