# NOTE(review): scraped page header — original lines read "Spaces:" /
# "Runtime error" / "Runtime error"; kept as a comment so the file parses.
import json
import os

import numpy as np
import torch
import torchaudio

from dancer_net.dancer_net import ShortChunkCNN
from preprocessing.preprocess import AudioPipeline
def load_labels(config_path):
    """Load class names from a JSON config file and return them sorted.

    The model's output units are assumed to follow the alphabetically
    sorted class order used at training time — TODO confirm against the
    training code.
    """
    with open(config_path) as f:
        config = json.load(f)
    return np.array(sorted(config["classes"]))


def load_spectrogram(audio_file, seconds):
    """Load an audio file, trim it to the first ``seconds``, and return its spectrogram."""
    waveform, sample_rate = torchaudio.load(audio_file)
    # Keep all channels, but only the first `seconds` worth of samples.
    waveform = waveform[:, : seconds * sample_rate]
    return AudioPipeline()(waveform)


def select_predicted_labels(scores, labels, threshold=0.5):
    """Return the entries of ``labels`` whose score is strictly above ``threshold``.

    ``scores`` and ``labels`` are parallel 1-D numpy arrays (multi-label
    prediction: zero, one, or many labels may be returned).
    """
    return labels[scores > threshold]


def main():
    """Run multi-label dance-style inference on a single audio file and print the result."""
    audio_file = "data/samples/mzm.iqskzxzx.aac.p.m4a.wav"
    seconds = 6
    model_path = "logs/20221226-230930"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")
    # Fall back to CPU when Apple-silicon MPS is not available,
    # instead of crashing on non-Mac machines.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    threshold = 0.5

    labels = load_labels(config_path)
    # Add a batch dimension before moving the spectrogram to the device.
    spectrogram = load_spectrogram(audio_file, seconds).unsqueeze(0).to(device)

    model = ShortChunkCNN(n_class=len(labels))
    # map_location lets a checkpoint saved on another device (e.g. GPU/MPS)
    # load on whatever device we selected above.
    model.load_state_dict(torch.load(weights, map_location=device))
    model = model.to(device).eval()

    with torch.no_grad():
        results = model(spectrogram)
    scores = results.squeeze(0).cpu().numpy()
    print(select_predicted_labels(scores, labels, threshold))


if __name__ == "__main__":
    main()