File size: 1,169 Bytes
c914273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torchaudio
from preprocessing.preprocess import AudioPipeline
from dancer_net.dancer_net import ShortChunkCNN
import torch
import numpy as np
import os
import json

if __name__ == "__main__":
    # Run a single audio clip through a trained ShortChunkCNN checkpoint and
    # print every label whose score exceeds `threshold`.

    audio_file = "data/samples/mzm.iqskzxzx.aac.p.m4a.wav"
    seconds = 6
    model_path = "logs/20221226-230930"
    weights = os.path.join(model_path, "dancer_net.pt")
    config_path = os.path.join(model_path, "config.json")
    threshold = 0.5

    # Prefer Apple's Metal backend, but fall back to the CPU so the script
    # still runs on machines (or torch builds) without MPS support.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    device = "mps" if mps_available else "cpu"

    # The label order is the sorted list of class names stored with the run.
    # NOTE(review): presumably this matches the order used at training time;
    # confirm against the training script.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    labels = np.array(sorted(config["classes"]))

    # Keep only the first `seconds` of audio (slice along the sample axis).
    # NOTE(review): the clip is not resampled here; confirm AudioPipeline
    # handles (or expects) the file's native sample rate.
    audio_pipeline = AudioPipeline()
    waveform, sample_rate = torchaudio.load(audio_file)
    waveform = waveform[:, :seconds * sample_rate]
    spectrogram = audio_pipeline(waveform)
    # Add a leading batch dimension before handing the input to the model.
    spectrogram = spectrogram.unsqueeze(0).to(device)

    model = ShortChunkCNN(n_class=len(labels))
    # Deserialize onto the CPU first so a checkpoint saved on another
    # accelerator (e.g. CUDA) loads regardless of the current hardware; the
    # model is moved to the target device afterwards.
    state_dict = torch.load(weights, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    with torch.no_grad():
        results = model(spectrogram)
    # Tensors produced under no_grad carry no autograd graph, so no explicit
    # detach is needed before converting to NumPy.
    results = results.squeeze(0).cpu().numpy()
    # NOTE(review): assumes the model emits per-class scores comparable to
    # `threshold` (e.g. sigmoid probabilities); confirm in ShortChunkCNN.
    results = results > threshold
    results = labels[results]
    print(results)