import json

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi
from torchaudio.functional import resample

# AudioMAE ViT-B/16 checkpoint (fine-tuned on AudioSet) loaded from the HF Hub via timm
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

# Index -> class-name mapping for the AudioSet labels
LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
AUDIOSET_LABELS = list(json.loads(requests.get(LABEL_URL).content).values())

SAMPLING_RATE = 16_000  # AudioMAE expects 16 kHz audio
# Log-mel normalization statistics used by AudioMAE
MEAN = -4.2677393
STD = 4.5689974


def preprocess(x: torch.Tensor):
    # Remove DC offset, then compute a 128-bin log mel-spectrogram (Kaldi-style fbank).
    # With the default 10 ms frame shift, 1024 frames cover ~10.24 s of audio.
    x = x - x.mean()
    melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)

    # Zero-pad or crop the time axis to exactly 1024 frames.
    if melspec.shape[0] < 1024:
        melspec = F.pad(melspec, (0, 0, 0, 1024 - melspec.shape[0]))
    else:
        melspec = melspec[:1024]

    # Normalize with the statistics defined above.
    melspec = (melspec - MEAN) / (STD * 2)
    return melspec


def predict(audio, start):
    sr, x = audio
    if x.shape[0] <= start * sr:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")

    # Gradio provides int16 PCM samples; convert to float in [-1, 1].
    x = torch.from_numpy(x) / (1 << 15)
    if x.ndim > 1:  # mix multi-channel audio down to mono
        x = x.mean(-1)
    assert x.ndim == 1
    x = resample(x[int(start * sr) :], sr, SAMPLING_RATE)
    x = preprocess(x)

    # The model treats the (1024, 128) spectrogram as a 1-channel image.
    with torch.inference_mode():
        logits = MODEL(x.view(1, 1, 1024, 128)).squeeze(0)

    # AudioSet is multi-label, so use per-class sigmoid instead of softmax.
    topk_probs, topk_classes = logits.sigmoid().topk(10)
    preds = [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]

    # Plot the spectrogram; with a 10 ms frame shift, 100 frames correspond to 1 s.
    fig = plt.figure()
    plt.imshow(x.T, origin="lower")
    plt.title("Log mel-spectrogram")
    plt.xlabel("Time (s)")
    plt.xticks(np.arange(11) * 100, np.arange(11))
    plt.yticks([0, 64, 128])
    plt.tight_layout()

    return preds, fig


DESCRIPTION = """
Classify audio into AudioSet classes with a ViT-B/16 pre-trained using the AudioMAE objective.

- For more information about AudioMAE, visit https://github.com/facebookresearch/AudioMAE.
- To see how to use the AudioMAE model with timm, visit https://huggingface.co/gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k.

Input audio is converted to a log mel-spectrogram and treated as a grayscale image. The model is a vanilla ViT-B/16.

NOTE: The AudioMAE model only accepts 10s of audio (10.24s to be exact). Longer audio will be cropped. Shorter audio will be zero-padded.
"""

gr.Interface(
    title="AudioSet classification with AudioMAE (ViT-B/16)",
    description=DESCRIPTION,
    fn=predict,
    inputs=["audio", "number"],  # input audio clip and start offset (in seconds)
    outputs=[
        gr.Dataframe(headers=["class", "score"], row_count=10, label="prediction"),
        gr.Plot(label="spectrogram"),
    ],
    examples=[
        ["LS_female_1462-170138-0008.flac", 0],
        ["LS_male_3170-137482-0005.flac", 0],
    ],
).launch()