"""Noise classification for WAV files using pooled YAMNet embeddings.

Importing this module loads the pickled classifier and label encoder from
``models/`` and fetches the YAMNet model from TensorFlow Hub.
"""

from typing import List, Optional, Tuple

import numpy as np
import torchaudio
import torchaudio.transforms as T
import joblib
from scipy.stats import skew, kurtosis
import tensorflow_hub as hub

# Sample rate every waveform is resampled to before being fed to YAMNet.
_TARGET_SAMPLE_RATE = 16000
# Number of highest-probability classes returned by classify_noise().
_TOP_K = 5

# Load classifier and label encoder
clf = joblib.load("models/noise_classifier.pkl")
label_encoder = joblib.load("models/label_encoder.pkl")

# Load YAMNet model
yamnet_model = hub.load("https://tfhub.dev/google/yamnet/1")


def get_yamnet_embedding(audio_path) -> Optional[np.ndarray]:
    """Extract a statistically pooled YAMNet embedding from an audio file.

    The file is loaded with torchaudio, resampled to 16 kHz when needed,
    reduced to a single channel (channel mean) and passed to YAMNet. The
    embeddings returned by the model are pooled along axis 0 with six
    statistics: mean, standard deviation, minimum, maximum, skewness and
    kurtosis.

    Args:
        audio_path: Path of the audio file to load.

    Returns:
        A 1-D array holding the six pooled statistics concatenated in the
        order listed above, or ``None`` if any step fails (the failure is
        printed, not raised).
    """
    try:
        waveform, sr = torchaudio.load(audio_path)

        if sr != _TARGET_SAMPLE_RATE:
            resampler = T.Resample(orig_freq=sr, new_freq=_TARGET_SAMPLE_RATE)
            waveform = resampler(waveform)

        # Collapse (channels, samples) to a 1-D signal: average multi-channel
        # audio, or just drop the channel axis when there is only one.
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0)
        else:
            waveform = waveform.squeeze(0)

        _, embeddings, _ = yamnet_model(waveform.numpy())
        # Convert once so the six reductions below share one NumPy array
        # instead of each converting the model output implicitly.
        embeddings = np.asarray(embeddings)

        # Statistical pooling along axis 0.
        # NOTE(review): skew/kurtosis yield NaN for constant columns (e.g. a
        # single-row input) -- confirm the classifier tolerates NaN features.
        pooled = [
            np.mean(embeddings, axis=0),
            np.std(embeddings, axis=0),
            np.min(embeddings, axis=0),
            np.max(embeddings, axis=0),
            skew(embeddings, axis=0),
            kurtosis(embeddings, axis=0),
        ]
        return np.concatenate(pooled)

    except Exception as e:
        # Best-effort extraction: report the failure and let the caller
        # decide how to handle the missing feature vector.
        print(f"Failed to process {audio_path}: {e}")
        return None


def classify_noise(audio_path, threshold=0.6) -> List[Tuple[str, float]]:
    """Return the most probable noise classes for an audio file.

    Args:
        audio_path: Path of the audio file to classify.
        threshold: Minimum top-class probability intended for rejecting a
            prediction as "Unknown". NOTE: rejection is currently disabled,
            so this value is accepted for API compatibility but has no
            effect on the result.

    Returns:
        A list of up to five ``(label, probability)`` tuples sorted by
        descending probability, or ``[("Unknown", 0.0)]`` when feature
        extraction fails.
    """
    feature = get_yamnet_embedding(audio_path)
    if feature is None:
        return [("Unknown", 0.0)]

    # predict_proba expects a 2-D (n_samples, n_features) input.
    probs = clf.predict_proba(feature.reshape(1, -1))[0]

    # Indices of the highest probabilities, best first.
    top_indices = np.argsort(probs)[::-1][:_TOP_K]

    # Decode all selected indices in one call rather than one call per index.
    labels = label_encoder.inverse_transform(top_indices)
    return [(label, probs[i]) for label, i in zip(labels, top_indices)]