Knight-coderr committed on
Commit
c26339d
·
verified ·
1 Parent(s): 06f4a76

Update utils/noise_classification.py

Browse files
Files changed (1) hide show
  1. utils/noise_classification.py +63 -63
utils/noise_classification.py CHANGED
@@ -1,63 +1,63 @@
1
- import numpy as np
2
- import torchaudio
3
- import torchaudio.transforms as T
4
- import joblib
5
- from scipy.stats import skew, kurtosis
6
- import tensorflow_hub as hub
7
-
8
- # Load classifier and label encoder
9
- clf = joblib.load("models/noise_classifier.pkl")
10
- label_encoder = joblib.load("models/label_encoder.pkl")
11
-
12
- # Load YAMNet model
13
- yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
14
-
15
- def get_yamnet_embedding(audio_path):
16
- """
17
- Extract YAMNet embeddings with statistical pooling from a WAV file.
18
- """
19
- try:
20
- waveform, sr = torchaudio.load(audio_path)
21
- if sr != 16000:
22
- resampler = T.Resample(orig_freq=sr, new_freq=16000)
23
- waveform = resampler(waveform)
24
- if waveform.size(0) > 1:
25
- waveform = waveform.mean(dim=0)
26
- else:
27
- waveform = waveform.squeeze(0)
28
-
29
- waveform_np = waveform.numpy()
30
- _, embeddings, _ = yamnet_model(waveform_np)
31
-
32
- # Statistical features
33
- mean = np.mean(embeddings, axis=0)
34
- std = np.std(embeddings, axis=0)
35
- min_val = np.min(embeddings, axis=0)
36
- max_val = np.max(embeddings, axis=0)
37
- skewness = skew(embeddings, axis=0)
38
- kurt = kurtosis(embeddings, axis=0)
39
-
40
- return np.concatenate([mean, std, min_val, max_val, skewness, kurt])
41
- except Exception as e:
42
- print(f"Failed to process {audio_path}: {e}")
43
- return None
44
-
45
- def classify_noise(audio_path, threshold=0.6):
46
- """
47
- Classify noise with rejection threshold for 'Unknown' label.
48
- """
49
- feature = get_yamnet_embedding(audio_path)
50
- if feature is None:
51
- return [("Unknown", 0.0)]
52
-
53
- feature = feature.reshape(1, -1)
54
- probs = clf.predict_proba(feature)[0]
55
-
56
- top_idx = np.argmax(probs)
57
- top_prob = probs[top_idx]
58
-
59
- if top_prob < threshold:
60
- return [("Unknown", top_prob)]
61
-
62
- top_indices = np.argsort(probs)[::-1][:5]
63
- return [(label_encoder.inverse_transform([i])[0], probs[i]) for i in top_indices]
 
1
+ import numpy as np
2
+ import torchaudio
3
+ import torchaudio.transforms as T
4
+ import joblib
5
+ from scipy.stats import skew, kurtosis
6
+ import tensorflow_hub as hub
7
+
8
+ # Load classifier and label encoder
9
+ clf = joblib.load("models/noise_classifier.pkl")
10
+ label_encoder = joblib.load("models/label_encoder.pkl")
11
+
12
+ # Load YAMNet model
13
+ yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")
14
+
15
+ def get_yamnet_embedding(audio_path):
16
+ """
17
+ Extract YAMNet embeddings with statistical pooling from a WAV file.
18
+ """
19
+ try:
20
+ waveform, sr = torchaudio.load(audio_path)
21
+ if sr != 16000:
22
+ resampler = T.Resample(orig_freq=sr, new_freq=16000)
23
+ waveform = resampler(waveform)
24
+ if waveform.size(0) > 1:
25
+ waveform = waveform.mean(dim=0)
26
+ else:
27
+ waveform = waveform.squeeze(0)
28
+
29
+ waveform_np = waveform.numpy()
30
+ _, embeddings, _ = yamnet_model(waveform_np)
31
+
32
+ # Statistical features
33
+ mean = np.mean(embeddings, axis=0)
34
+ std = np.std(embeddings, axis=0)
35
+ min_val = np.min(embeddings, axis=0)
36
+ max_val = np.max(embeddings, axis=0)
37
+ skewness = skew(embeddings, axis=0)
38
+ kurt = kurtosis(embeddings, axis=0)
39
+
40
+ return np.concatenate([mean, std, min_val, max_val, skewness, kurt])
41
+ except Exception as e:
42
+ print(f"Failed to process {audio_path}: {e}")
43
+ return None
44
+
45
+ def classify_noise(audio_path, threshold=0.6):
46
+ """
47
+ Classify noise with rejection threshold for 'Unknown' label.
48
+ """
49
+ feature = get_yamnet_embedding(audio_path)
50
+ if feature is None:
51
+ return [("Unknown", 0.0)]
52
+
53
+ feature = feature.reshape(1, -1)
54
+ probs = clf.predict_proba(feature)[0]
55
+
56
+ top_idx = np.argmax(probs)
57
+ top_prob = probs[top_idx]
58
+
59
+ # if top_prob < threshold:
60
+ # return [("Unknown", top_prob)]
61
+
62
+ top_indices = np.argsort(probs)[::-1][:5]
63
+ return [(label_encoder.inverse_transform([i])[0], probs[i]) for i in top_indices]