voj / audio_class_predictor.py
Your Name
update app.py
dddb9f9
raw
history blame
1.35 kB
import timm
import json
import torch
from torchaudio.functional import resample
import numpy as np
from torchaudio.compliance import kaldi
import torch.nn.functional as F
import requests
# Hugging Face Hub id of the timm checkpoint used for classification.
TAG = "gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft_as20k"
# Loaded once at import time and switched to eval mode; shared by every call
# to predict_class().
MODEL = timm.create_model(f"hf_hub:{TAG}", pretrained=True).eval()

# JSON object whose values are the class names; the list below keeps the
# file's own ordering, and predict_class() indexes it with the model's class
# indices.
LABEL_URL = "https://huggingface.co/datasets/huggingface/label-files/raw/main/audioset-id2label.json"
# Bound the request and surface HTTP errors here: without a timeout a stalled
# connection would block the import forever, and without a status check an
# error page would be handed to the JSON parser instead of being reported.
_label_response = requests.get(LABEL_URL, timeout=30)
_label_response.raise_for_status()
AUDIOSET_LABELS = list(json.loads(_label_response.content).values())

# Not referenced in the visible code; presumably the sample rate (Hz) the
# caller must resample audio to before calling predict_class() — confirm.
SAMPLING_RATE = 16_000
# Offset and scale applied to the mel spectrogram in preprocess().
# NOTE(review): presumably dataset statistics published with the checkpoint —
# confirm before changing.
MEAN = -4.2677393
STD = 4.5689974
def preprocess(x: torch.Tensor):
    """Turn a waveform into the fixed-size, normalised mel spectrogram fed to MODEL.

    Args:
        x: Waveform tensor. A leading axis is added before the filterbank
            call, so this is presumably 1-D (predict_class() asserts that
            before calling) — confirm for other callers.

    Returns:
        A tensor with exactly 1024 rows (frames) of 128 mel bins, shifted by
        ``MEAN`` and divided by ``2 * STD``.
    """
    # Remove the DC offset before computing the filterbank features.
    centered = x - x.mean()
    # No sample_frequency is passed, so fbank's own default rate applies.
    fbank = kaldi.fbank(
        centered.unsqueeze(0),
        htk_compat=True,
        window_type="hanning",
        num_mel_bins=128,
    )

    # Force the time axis to exactly 1024 frames: crop long clips, and
    # zero-pad short ones at the end (padding happens before normalisation).
    n_frames = fbank.shape[0]
    if n_frames >= 1024:
        fbank = fbank[:1024]
    else:
        fbank = F.pad(fbank, (0, 0, 0, 1024 - n_frames))

    return (fbank - MEAN) / (STD * 2)
def predict_class(x: np.ndarray):
    """Classify a waveform and return the ten highest-scoring labels.

    Args:
        x: Audio samples as a NumPy array. A 1-D array is used directly; an
            array with more dimensions is averaged over its last axis
            (presumably the channel axis of a ``(samples, channels)`` array —
            confirm against the caller) and must be 1-D after that. Any
            floating-point dtype is accepted. No resampling is done here.

    Returns:
        A list of ten ``[label, score]`` pairs, where ``score`` is the
        sigmoid of the class logit multiplied by 100, in the order produced
        by ``topk`` (highest score first).

    Raises:
        AssertionError: If the waveform is not 1-D after the last-axis
            average.
    """
    waveform = torch.from_numpy(x)
    if waveform.is_floating_point():
        # from_numpy keeps the NumPy dtype, and NumPy defaults to float64.
        # MODEL is never cast in this file, so it presumably holds float32
        # weights and would reject float64 features; normalise the dtype up
        # front. This is a no-op for float32 input.
        waveform = waveform.to(torch.float32)
    # NOTE(review): integer PCM is left untouched on purpose — the correct
    # scaling is not knowable here, and mean() below rejects integer tensors,
    # so callers must convert such audio to float themselves.

    if waveform.ndim > 1:
        waveform = waveform.mean(-1)
    assert waveform.ndim == 1

    features = preprocess(waveform)
    with torch.inference_mode():
        # Add batch and channel axes around the (1024, 128) spectrogram.
        logits = MODEL(features.view(1, 1, 1024, 128)).squeeze(0)

    # Independent per-class sigmoid scores; keep the ten largest.
    topk_probs, topk_classes = logits.sigmoid().topk(10)
    return [
        [AUDIOSET_LABELS[cls], prob * 100]
        for cls, prob in zip(topk_classes.tolist(), topk_probs.tolist())
    ]