import torch
import logging
import sys, os
import torch.nn.functional as F
from pathlib import Path
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1])))
sys.path.append(os.path.join(str(Path(os.path.realpath(__file__)).parents[1]), 'model', 'emotion'))
from wavlm_emotion import WavLMWrapper
from whisper_emotion import WhisperWrapper
# Configure console logging
logging.basicConfig(
    format='%(asctime)s %(levelname)-3s ==> %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S'
)
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
if __name__ == '__main__':
    label_list = [
        'Anger',
        'Contempt',
        'Disgust',
        'Fear',
        'Happiness',
        'Neutral',
        'Sadness',
        'Surprise',
        'Other'
    ]
    # Find device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        print('GPU available, using GPU')
    # Define the model wrappers
    # Note that the ensemble yields better performance than a single model
    model_path = "model"
    wavlm_model = WavLMWrapper(
        pretrain_model="wavlm_large",
        finetune_method="finetune",
        output_class_num=9,
        freeze_params=True,
        use_conv_output=True,
        detailed_class_num=17
    ).to(device)
    whisper_model = WhisperWrapper(
        pretrain_model="whisper_large",
        finetune_method="lora",
        lora_rank=16,
        output_class_num=9,
        freeze_params=True,
        use_conv_output=True,
        detailed_class_num=17
    ).to(device)
    # Load the fine-tuned weights; Whisper additionally needs its LoRA adapter weights
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_emotion.pt"), weights_only=True), strict=False)
    whisper_model.load_state_dict(torch.load(os.path.join(model_path, "whisper_emotion_lora.pt")), strict=False)
    wavlm_model.load_state_dict(torch.load(os.path.join(model_path, "wavlm_emotion.pt"), weights_only=True), strict=False)
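    # Sketch (not part of the original script): strict=False silently tolerates key
    # mismatches, so re-applying one checkpoint and logging the result makes any
    # mismatch visible; load_state_dict reports the missing and unexpected keys.
    load_report = wavlm_model.load_state_dict(
        torch.load(os.path.join(model_path, "wavlm_emotion.pt"), weights_only=True),
        strict=False
    )
    logging.info(
        f"wavlm_emotion.pt: {len(load_report.missing_keys)} missing, "
        f"{len(load_report.unexpected_keys)} unexpected keys"
    )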
    wavlm_model.eval()
    whisper_model.eval()
    # Audio must be sampled at 16 kHz
    data = torch.zeros([1, 16000]).to(device)  # 1 second of silence as a placeholder input
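    # Sketch (assumption, not part of the original script): run on a real recording
    # instead of the silent placeholder above. "example.wav" is a hypothetical path;
    # torchaudio downmixes to mono and resamples to the 16 kHz input the models expect.
    audio_path = "example.wav"
    if os.path.exists(audio_path):
        import torchaudio
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono, shape [1, T]
        if sample_rate != 16000:
            waveform = torchaudio.functional.resample(waveform, sample_rate, 16000)
        data = waveform.to(device)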
    # Forward pass through each model; return_feature=True also returns the pooled embeddings
    whisper_logits, whisper_embedding, _, _, _, _ = whisper_model(
        data, return_feature=True
    )
    wavlm_logits, wavlm_embedding, _, _, _, _ = wavlm_model(
        data, return_feature=True
    )
    # Ensemble: average the two models' logits, then take the softmax over the 9 classes
    ensemble_logits = (whisper_logits + wavlm_logits) / 2
    ensemble_prob = F.softmax(ensemble_logits, dim=1)

    print(ensemble_prob.shape)
    print(whisper_embedding.shape)
    print(wavlm_embedding.shape)
    print(label_list[torch.argmax(ensemble_prob).detach().cpu().item()])
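    # Sketch (not part of the original script): log the full per-class distribution,
    # which is often more informative than the argmax label alone.
    for label, prob in zip(label_list, ensemble_prob.squeeze(0).tolist()):
        logging.info(f"{label}: {prob:.3f}")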