File size: 2,002 Bytes
c26339d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
import torchaudio
import torchaudio.transforms as T
import joblib
from scipy.stats import skew, kurtosis
import tensorflow_hub as hub

# Pre-trained classifier (queried via predict_proba in classify_noise) and the
# label encoder that maps its integer class indices back to class labels.
# joblib artifacts are pickle-based: only load files from a trusted source.
clf, label_encoder = (
    joblib.load("models/noise_classifier.pkl"),
    joblib.load("models/label_encoder.pkl"),
)

# YAMNet audio-event model, fetched from TensorFlow Hub by URL.
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")

def _load_mono_16k(audio_path):
    """Load ``audio_path`` as a 1-D mono waveform tensor at 16 kHz.

    Multi-channel audio is averaged to mono *before* resampling, so only a
    single channel goes through the resampler (resampling and averaging are
    both linear, so the order only affects float rounding).
    """
    waveform, sr = torchaudio.load(audio_path)
    # Dim 0 is treated as the channel axis: average it away, or drop it if mono.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0)
    else:
        waveform = waveform.squeeze(0)
    if sr != 16000:
        waveform = T.Resample(orig_freq=sr, new_freq=16000)(waveform)
    return waveform


def _pool_statistics(embeddings):
    """Pool per-frame embeddings into one vector of six statistics over axis 0.

    Returns ``[mean, std, min, max, skewness, kurtosis]`` concatenated, so an
    input of shape (frames, D) yields a 1-D vector of length 6 * D.
    """
    # Convert the model output to an ndarray once instead of once per statistic.
    frames = np.asarray(embeddings)
    return np.concatenate([
        np.mean(frames, axis=0),
        np.std(frames, axis=0),
        np.min(frames, axis=0),
        np.max(frames, axis=0),
        skew(frames, axis=0),
        kurtosis(frames, axis=0),
    ])


def get_yamnet_embedding(audio_path):
    """
    Extract YAMNet embeddings with statistical pooling from a WAV file.

    The audio is downmixed to mono, resampled to 16 kHz, passed through
    ``yamnet_model``, and the per-frame embeddings (second model output) are
    pooled over the frame axis into a single fixed-length feature vector.

    Args:
        audio_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        1-D numpy array ``[mean, std, min, max, skew, kurtosis]`` of the
        embeddings, or ``None`` if any step fails (the error is printed).

    NOTE(review): if the model yields a single frame (very short clips) or an
    embedding dimension is constant across frames, std is 0 there and SciPy's
    skew/kurtosis can come out NaN -- confirm the classifier tolerates that.
    """
    try:
        waveform = _load_mono_16k(audio_path)
        _, embeddings, _ = yamnet_model(waveform.numpy())
        return _pool_statistics(embeddings)
    except Exception as e:
        # Deliberate best-effort: a bad file is reported and skipped (caller
        # checks for None) rather than aborting a batch run.
        print(f"Failed to process {audio_path}: {e}")
        return None

def classify_noise(audio_path, threshold=0.6, top_k=5):
    """
    Classify noise in an audio file and return the most likely labels.

    Args:
        audio_path: Path to the audio file to classify.
        threshold: Intended confidence cutoff for an 'Unknown' rejection.
            Currently IGNORED -- rejection is disabled (see note in body);
            the parameter is kept so existing callers keep working.
        top_k: Maximum number of ``(label, probability)`` pairs to return.

    Returns:
        List of ``(label, probability)`` tuples in descending probability
        order (at most ``top_k`` entries), or ``[("Unknown", 0.0)]`` when
        feature extraction failed.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    feature = get_yamnet_embedding(audio_path)
    if feature is None:
        return [("Unknown", 0.0)]

    # predict_proba takes a 2-D (n_samples, n_features) batch; we have one sample.
    probs = clf.predict_proba(feature.reshape(1, -1))[0]

    # NOTE: rejection is intentionally disabled, so `threshold` has no effect.
    # To restore it: if probs.max() < threshold, return [("Unknown", probs.max())].

    # Indices of the top_k probabilities, highest first.
    top_indices = np.argsort(probs)[::-1][:top_k]
    # One vectorized decode instead of one inverse_transform call per index.
    labels = label_encoder.inverse_transform(top_indices)
    return list(zip(labels, probs[top_indices]))