import json
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import timm
import torch
import torch.nn.functional as F
from torchaudio.compliance import kaldi
from torchaudio.functional import resample
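# AudioMAE checkpoint: ViT-B/16 pre-trained on AudioSet-2M and fine-tuned on AudioSet-20k,
# loaded from the Hugging Face Hub via timm.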
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()
LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
AUDIOSET_LABELS = list(json.loads(requests.get(LABEL_URL).content).values())
SAMPLING_RATE = 16_000
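# Log-mel statistics used to normalize the model input.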
MEAN = -4.2677393
STD = 4.5689974
def preprocess(x: torch.Tensor):
    x = x - x.mean()  # remove DC offset
    # 128-bin log mel-spectrogram, 25 ms window / 10 ms hop (kaldi defaults) -> (num_frames, 128)
    melspec = kaldi.fbank(x.unsqueeze(0), htk_compat=True, window_type="hanning", num_mel_bins=128)
    # Pad or crop along the time axis to exactly 1024 frames (10.24 s)
    if melspec.shape[0] < 1024:
        melspec = F.pad(melspec, (0, 0, 0, 1024 - melspec.shape[0]))
    else:
        melspec = melspec[:1024]
    melspec = (melspec - MEAN) / (STD * 2)  # AudioMAE input normalization
    return melspec
def predict(audio, start):
    sr, x = audio  # Gradio "audio" input: (sample_rate, int16 numpy array)
    if x.shape[0] < start * sr:
        raise gr.Error(f"`start` ({start}) must be smaller than audio duration ({x.shape[0] / sr:.0f}s)")

    x = torch.from_numpy(x) / (1 << 15)  # int16 -> float in [-1, 1)
    if x.ndim > 1:
        x = x.mean(-1)  # downmix to mono
    assert x.ndim == 1

    x = resample(x[int(start * sr) :], sr, SAMPLING_RATE)
    x = preprocess(x)

    with torch.inference_mode():
        logits = MODEL(x.view(1, 1, 1024, 128)).squeeze(0)
        topk_probs, topk_classes = logits.sigmoid().topk(10)
    preds = [[AUDIOSET_LABELS[cls], prob.item() * 100] for cls, prob in zip(topk_classes, topk_probs)]

    # Plot the (1024, 128) spectrogram with time on the x-axis (100 frames = 1 s at a 10 ms hop)
    fig = plt.figure()
    plt.imshow(x.T, origin="lower")
    plt.title("Log mel-spectrogram")
    plt.xlabel("Time (s)")
    plt.xticks(np.arange(11) * 100, np.arange(11))
    plt.yticks([0, 64, 128])
    plt.tight_layout()
    return preds, fig
DESCRIPTION = """
Classify audio into AudioSet classes with a ViT-B/16 pre-trained using the AudioMAE objective.

- For more information about AudioMAE, visit https://github.com/facebookresearch/AudioMAE.
- For how to use the AudioMAE model in timm, visit https://huggingface.co/gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k.

The input audio is converted to a log mel-spectrogram and treated as a grayscale image. The model itself is a vanilla ViT-B/16.

NOTE: The AudioMAE model only accepts 10s of audio (10.24s to be exact). Longer audio will be cropped. Shorter audio will be zero-padded.
"""
gr.Interface(
    title="AudioSet classification with AudioMAE (ViT-B/16)",
    description=DESCRIPTION,
    fn=predict,
    inputs=["audio", "number"],
    outputs=[
        gr.Dataframe(headers=["class", "score"], row_count=10, label="prediction"),
        gr.Plot(label="spectrogram"),
    ],
    examples=[
        ["LS_female_1462-170138-0008.flac", 0],
        ["LS_male_3170-137482-0005.flac", 0],
    ],
).launch()