Arousal - Dominance - Valence

Dimensional speech emotion recognition model (arousal / dominance / valence) that averages the predictions of a WavLM model and a wav2vec 2.0 model. It achieves 0.6760566 valence CCC on MSP Podcast Test 1 and is used as the teacher for Wav2Small.

PapersWithCode / arXiv

Wav2Small: Distilling Wav2Vec2 to 72K parameters for low-resource
speech emotion recognition.
D. Kounadis-Bastian, O. Schrüfer, A. Derington, H. Wierstorf,
F. Eyben, F. Burkhardt, B.W. Schuller. 2024, arXiv preprint
CCC on MSP Podcast v1.7

| Split  | Valence   | Dominance | Arousal   |
|--------|-----------|-----------|-----------|
| Test 1 | 0.6760566 | 0.6840044 | 0.7620181 |
| Test 2 | 0.4229267 | 0.4684658 | 0.4857733 |

HowTo

import librosa
import torch
import types
import torch.nn as nn
from transformers import AutoModelForAudioClassification
from transformers.models.wav2vec2.modeling_wav2vec2 import (Wav2Vec2Model,
                                                  Wav2Vec2PreTrainedModel)


# Device on which both models and the input tensor are placed.
device = 'cpu'

# Decode test.wav at a 16 kHz sampling rate and prepend a batch axis.
signal = torch.from_numpy(
    librosa.load('test.wav', sr=16000)[0]
).unsqueeze(0)

class ADV(nn.Module):
    """Two-layer head: linear -> tanh -> linear.

    The first layer keeps the width at ``config.hidden_size``; the second
    projects to ``config.num_labels`` outputs.
    """

    def __init__(self, config):
        """Create both linear layers from ``config``.

        Args:
            config: object exposing ``hidden_size`` and ``num_labels``.
        """
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, x):
        """Return ``out_proj(tanh(dense(x)))``."""
        hidden = torch.tanh(self.dense(x))
        return self.out_proj(hidden)


class Dawn(Wav2Vec2PreTrainedModel):
    r"""Wav2Vec2 encoder followed by an ``ADV`` head on time-averaged features.

    Reference: https://arxiv.org/abs/2203.07378
    """

    def __init__(self, config):
        """Create the Wav2Vec2 encoder and the ``ADV`` head from ``config``."""

        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = ADV(config)

    def forward(self, x):
        """Normalise each waveform, encode it and apply the head.

        Args:
            x: tensor indexed as (batch, samples); statistics are taken
                along dim 1. The caller's tensor is not modified.

        Returns:
            Output of ``self.classifier`` applied to the encoder's
            ``last_hidden_state`` averaged over dim 1.
        """
        # Subtract out of place: an in-place ``x -= ...`` would overwrite the
        # caller's tensor, so a waveform that is afterwards handed to another
        # model would arrive already mean-subtracted.
        x = x - x.mean(1, keepdim=True)
        # The small constant keeps the division finite for all-zero input.
        variance = (x * x).mean(1, keepdim=True) + 1e-7
        x = self.wav2vec2(x / variance.sqrt())
        return self.classifier(x.last_hidden_state.mean(1))


def _forward(self, x):
    """Replacement ``forward`` that takes only the waveform batch.

    x: (batch, audio-samples-16KHz)

    The input is shifted and scaled with ``self.config.mean`` /
    ``self.config.std``, encoded by ``self.ssl_model``, pooled into an
    attention-weighted mean and standard deviation over dim 1, and the
    concatenation of both is passed to ``self.ser_model``.
    """
    # NOTE(review): the mean is added, not subtracted (the author tagged this
    # line "sgn") -- confirm against the checkpoint's stored statistics.
    normalised = (x + self.config.mean) / self.config.std
    hidden = self.ssl_model(normalised, attention_mask=None).last_hidden_state

    # Attention weights, normalised with a softmax over dim 1.
    scores = torch.matmul(
        torch.tanh(self.pool_model.sap_linear(hidden)),
        self.pool_model.attention,
    )
    weights = torch.softmax(scores, dim=1)

    # Weighted first and second moments -> weighted mean and std.
    weighted_mean = torch.sum(hidden * weights, dim=1)
    second_moment = torch.sum(hidden * hidden * weights, dim=1)
    # Clamp before sqrt so rounding can never produce a negative variance.
    weighted_std = torch.sqrt(
        torch.clamp(second_moment - weighted_mean * weighted_mean, min=1e-7)
    )

    return self.ser_model(torch.cat((weighted_mean, weighted_std), dim=1))


# WavLM
# trust_remote_code=True lets transformers execute the model code that is
# downloaded together with this checkpoint.
base = AutoModelForAudioClassification.from_pretrained(
    '3loi/SER-Odyssey-Baseline-WavLM-Multi-Attributes',
    trust_remote_code=True,
)
base = base.to(device).eval()
# Bind _forward to this instance so that base(x) accepts just the waveforms.
base.forward = types.MethodType(_forward, base)

# Wav2Vec2
dawn = Dawn.from_pretrained(
    'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim'
)
dawn = dawn.to(device).eval()


def wav2small(x):
    """Return the equal-weight average of the ``dawn`` and ``base`` outputs for ``x``."""
    # ``dawn`` is evaluated before ``base``, as in the one-line expression.
    wav2vec2_out = dawn(x)
    wavlm_out = base(x)
    return .5 * wav2vec2_out + .5 * wavlm_out

# Run the ensemble on the loaded waveform and report the three outputs of
# the first (only) batch item.
pred = wav2small(signal.to(device))
# The adjacent f-strings form ONE print() argument, so the fields are
# separated by exactly one space (a comma between them would make print()
# insert an extra separator).
print(f'Arousal={pred[0, 0]} '
      f'Dominance={pred[0, 1]} '
      f'Valence={pred[0, 2]}')
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.