import numpy as np
import librosa
import torch
import torch.nn as nn
# import pywt
from scipy import signal, stats

def compute_cwt_power_spectrum(audio, sample_rate, num_freqs=128, f_min=20, f_max=None):
    """
    Compute the power spectrum of the continuous wavelet transform using a Morlet wavelet.

    Parameters:
        audio: torch.Tensor
            Input audio signal
        sample_rate: int
            Sampling rate of the audio
        num_freqs: int
            Number of frequency bins for the CWT
        f_min: float
            Minimum frequency to analyze
        f_max: float or None
            Maximum frequency to analyze (defaults to the Nyquist frequency)

    Returns:
        torch.Tensor: CWT power spectrum of shape (num_freqs, num_samples)
    """
    # Convert to numpy
    audio_np = audio.cpu().numpy()

    # Default f_max to the Nyquist frequency if not specified
    if f_max is None:
        f_max = sample_rate // 2

    # Generate logarithmically spaced frequency bins
    frequencies = np.logspace(
        np.log10(f_min),
        np.log10(f_max),
        num=num_freqs
    )

    # Convert frequencies to wavelet widths (in samples). For morlet2 the
    # relation is width = w * fs / (2 * pi * f), where w is the width
    # parameter of the Morlet wavelet.
    w = 5.0
    widths = w * sample_rate / (2 * frequencies * np.pi)

    # Compute CWT using a Morlet wavelet. Note: scipy.signal.cwt is
    # deprecated in recent SciPy (since 1.12); consider pywt.cwt for new code.
    cwt = signal.cwt(
        audio_np,
        signal.morlet2,
        widths,
        w=w
    )

    # Compute power spectrum (magnitude squared)
    power_spectrum = np.abs(cwt) ** 2

    # Convert to torch tensor
    return torch.FloatTensor(power_spectrum)
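
# Hypothetical convenience wrapper (not part of the original code): convert
# the CWT power map to decibels, mirroring the librosa.power_to_db call used
# for the mel spectrogram below. A minimal sketch; the max-referenced scaling
# is an assumption.
def cwt_power_to_db(audio, sample_rate, num_freqs=128, f_min=20, f_max=None):
    power = compute_cwt_power_spectrum(audio, sample_rate, num_freqs, f_min, f_max)
    # power_to_db expects a power spectrogram; ref=np.max normalizes to a 0 dB peak
    return torch.FloatTensor(librosa.power_to_db(power.numpy(), ref=np.max))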

# def compute_wavelet_transform(audio, wavelet, decompos_level):
#     """Compute wavelet decomposition of the audio signal."""
#     # Convert to numpy and ensure 1D
#     audio_np = audio.cpu().numpy()
#
#     # Perform wavelet decomposition
#     coeffs = pywt.wavedec(audio_np, wavelet, level=decompos_level)
#
#     # Stack coefficients into a 2D array
#     # First, pad all coefficient arrays to the same length
#     max_len = max(len(c) for c in coeffs)
#     padded_coeffs = []
#     for coeff in coeffs:
#         pad_len = max_len - len(coeff)
#         if pad_len > 0:
#             padded_coeff = np.pad(coeff, (0, pad_len), mode='constant')
#         else:
#             padded_coeff = coeff
#         padded_coeffs.append(padded_coeff)
#
#     # Stack into 2D array where each row is a different scale
#     wavelet_features = np.stack(padded_coeffs)
#
#     # Convert to tensor
#     return torch.FloatTensor(wavelet_features)

def compute_melspectrogram(audio, sample_rate):
    """Compute a 128-band log-mel spectrogram (in dB)."""
    mel_spec = librosa.feature.melspectrogram(
        y=audio.cpu().numpy(),
        sr=sample_rate,
        n_mels=128
    )
    return torch.FloatTensor(librosa.power_to_db(mel_spec))


def compute_mfcc(audio, sample_rate):
    """Compute 20 mel-frequency cepstral coefficients."""
    mfcc = librosa.feature.mfcc(
        y=audio.cpu().numpy(),
        sr=sample_rate,
        n_mfcc=20
    )
    return torch.FloatTensor(mfcc)


def compute_chroma(audio, sample_rate):
    """Compute a 12-bin chromagram from the STFT."""
    chroma = librosa.feature.chroma_stft(
        y=audio.cpu().numpy(),
        sr=sample_rate
    )
    return torch.FloatTensor(chroma)
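
# Hypothetical example (not part of the original code): with librosa's default
# n_fft and hop_length, the three wrappers above produce tensors with the same
# number of frames, so they can be concatenated along the feature axis.
# A minimal sketch under that assumption.
def compute_2d_features(audio, sample_rate):
    # Resulting shape: (128 + 20 + 12, n_frames)
    return torch.cat([
        compute_melspectrogram(audio, sample_rate),
        compute_mfcc(audio, sample_rate),
        compute_chroma(audio, sample_rate),
    ], dim=0)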

def compute_time_domain_features(audio, sample_rate, frame_length=2048, hop_length=128):
    """
    Compute time-domain features from an audio signal.
    Returns a dictionary of single-element tensors.
    """
    # Convert to numpy
    audio_np = audio.cpu().numpy()

    # Initialize dictionary for features
    features = {}

    # 1. Zero Crossing Rate (summed over frames)
    zcr = librosa.feature.zero_crossing_rate(
        y=audio_np,
        frame_length=frame_length,
        hop_length=hop_length
    )
    features['zcr'] = torch.Tensor([zcr.sum()])

    # 2. Root Mean Square Energy (averaged over frames)
    rms = librosa.feature.rms(
        y=audio_np,
        frame_length=frame_length,
        hop_length=hop_length
    )
    features['rms_energy'] = torch.Tensor([rms.mean()])

    # 3. Temporal statistics: per-frame mean/std/max, averaged over frames.
    # librosa.util.frame returns an array of shape (frame_length, n_frames).
    frames = librosa.util.frame(audio_np, frame_length=frame_length, hop_length=hop_length)
    features['mean'] = torch.Tensor([np.mean(frames, axis=0).mean()])
    features['std'] = torch.Tensor([np.std(frames, axis=0).mean()])
    features['max'] = torch.Tensor([np.max(frames, axis=0).mean()])

    # 4. Tempo and Beat Features
    # (librosa.beat.tempo was moved to librosa.feature.tempo in librosa 0.10)
    onset_env = librosa.onset.onset_strength(y=audio_np, sr=sample_rate)
    tempo = librosa.beat.tempo(onset_envelope=onset_env, sr=sample_rate)
    features['tempo'] = torch.Tensor(tempo)

    # 5. Amplitude Envelope, approximated by the mean STFT magnitude
    envelope = np.abs(librosa.stft(audio_np, n_fft=frame_length, hop_length=hop_length))
    features['envelope'] = torch.Tensor([np.mean(envelope, axis=0).mean()])

    return features
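
# Hypothetical helper (not part of the original code): every entry in the
# feature dicts produced here is a small tensor, so a dict flattens naturally
# into a fixed-size vector for downstream models. A minimal sketch; key order
# follows dict insertion order.
def feature_dict_to_vector(features):
    return torch.cat([v.flatten() for v in features.values()])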

def compute_frequency_domain_features(audio, sample_rate, n_fft=2048, hop_length=512):
    """
    Compute frequency-domain features from an audio signal.
    Returns a dictionary of single-element tensors; features that fail to
    compute are returned as NaN.
    """
    # Convert to numpy
    audio_np = audio.cpu().numpy()

    # Initialize dictionary for features
    features = {}

    # 1. Spectral Centroid
    try:
        spectral_centroids = librosa.feature.spectral_centroid(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
        )
        features['spectral_centroid'] = torch.FloatTensor([spectral_centroids.max()])
    except Exception:
        features['spectral_centroid'] = torch.FloatTensor([np.nan])

    # 2. Spectral Rolloff
    try:
        spectral_rolloff = librosa.feature.spectral_rolloff(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
        )
        features['spectral_rolloff'] = torch.FloatTensor([spectral_rolloff.max()])
    except Exception:
        features['spectral_rolloff'] = torch.FloatTensor([np.nan])

    # 3. Spectral Bandwidth
    try:
        spectral_bandwidth = librosa.feature.spectral_bandwidth(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length
        )
        features['spectral_bandwidth'] = torch.FloatTensor([spectral_bandwidth.max()])
    except Exception:
        features['spectral_bandwidth'] = torch.FloatTensor([np.nan])

    # 4. Spectral Contrast
    try:
        spectral_contrast = librosa.feature.spectral_contrast(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            fmin=20,       # Lower minimum frequency
            n_bands=4,     # Reduced number of bands
            quantile=0.02
        )
        features['spectral_contrast'] = torch.FloatTensor([spectral_contrast.mean()])
    except Exception:
        features['spectral_contrast'] = torch.FloatTensor([np.nan])

    # 5. Spectral Flatness
    try:
        spectral_flatness = librosa.feature.spectral_flatness(
            y=audio_np,
            n_fft=n_fft,
            hop_length=hop_length
        )
        features['spectral_flatness'] = torch.FloatTensor([spectral_flatness.max()])
    except Exception:
        features['spectral_flatness'] = torch.FloatTensor([np.nan])

    # 6. Spectral Flux (frame-to-frame change in the magnitude spectrum)
    try:
        stft = np.abs(librosa.stft(audio_np, n_fft=n_fft, hop_length=hop_length))
        spectral_flux = np.diff(stft, axis=1)
        spectral_flux = np.pad(spectral_flux, ((0, 0), (1, 0)), mode='constant')
        features['spectral_flux'] = torch.FloatTensor([np.std(spectral_flux)])
    except Exception:
        features['spectral_flux'] = torch.FloatTensor([np.nan])

    # 7. MFCCs (Mel-Frequency Cepstral Coefficients)
    try:
        mfccs = librosa.feature.mfcc(
            y=audio_np,
            sr=sample_rate,
            n_mfcc=13,    # Number of MFCCs to compute
            n_fft=n_fft,
            hop_length=hop_length
        )
        features['mfcc_mean'] = torch.FloatTensor([mfccs.mean()])
    except Exception:
        features['mfcc_mean'] = torch.FloatTensor([np.nan])

    # 8. Chroma Features
    try:
        chroma = librosa.feature.chroma_stft(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length
        )
        features['chroma_mean'] = torch.FloatTensor([chroma.mean()])
    except Exception:
        features['chroma_mean'] = torch.FloatTensor([np.nan])

    # Shared magnitude spectrogram for the moment-based features below
    try:
        stft_mag = np.abs(librosa.stft(audio_np, n_fft=n_fft, hop_length=hop_length))
    except Exception:
        stft_mag = None

    # 9. Spectral Kurtosis. librosa has no spectral_kurtosis function, so this
    # computes the per-frame kurtosis of the magnitude spectrum with
    # scipy.stats (a swapped-in implementation, not a librosa API).
    try:
        spectral_kurtosis = stats.kurtosis(stft_mag, axis=0)
        features['spectral_kurtosis'] = torch.FloatTensor([spectral_kurtosis.mean()])
    except Exception:
        features['spectral_kurtosis'] = torch.FloatTensor([np.nan])

    # 10. Spectral Skewness, computed the same way via scipy.stats
    try:
        spectral_skewness = stats.skew(stft_mag, axis=0)
        features['spectral_skewness'] = torch.FloatTensor([spectral_skewness.mean()])
    except Exception:
        features['spectral_skewness'] = torch.FloatTensor([np.nan])

    # 11. Spectral Slope: per-frame linear fit of magnitude against frequency
    # (also not a librosa API; np.polyfit handles the column-wise 2-D fit)
    try:
        freqs = librosa.fft_frequencies(sr=sample_rate, n_fft=n_fft)
        spectral_slope = np.polyfit(freqs, stft_mag, 1)[0]
        features['spectral_slope'] = torch.FloatTensor([spectral_slope.mean()])
    except Exception:
        features['spectral_slope'] = torch.FloatTensor([np.nan])

    # 12. Tonnetz (Tonal Centroid Features)
    try:
        tonnetz = librosa.feature.tonnetz(
            y=audio_np,
            sr=sample_rate
        )
        features['tonnetz_mean'] = torch.FloatTensor([tonnetz.mean()])
    except Exception:
        features['tonnetz_mean'] = torch.FloatTensor([np.nan])

    return features
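
# Hypothetical helper (not part of the original code): the per-feature
# try/except blocks above degrade failures to NaN sentinels instead of
# raising, so a caller may want to filter those out. A minimal sketch.
def drop_nan_features(features):
    return {k: v for k, v in features.items() if not torch.isnan(v).any()}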

def compute_all_features(audio, sample_rate, wavelet='db1', decompos_level=4):
    """
    Intended to compute all available features and return them in a dictionary.
    Most feature groups are currently commented out, so only the
    frequency-domain features are returned.
    """
    features = {}

    # Basic transformations
    # features['wavelet'] = compute_wavelet_transform(audio, wavelet, decompos_level)
    # features['melspectrogram'] = compute_melspectrogram(audio, sample_rate)
    # features['mfcc'] = compute_mfcc(audio, sample_rate)
    # features['chroma'] = compute_chroma(audio, sample_rate)
    # features['cwt_power'] = compute_cwt_power_spectrum(
    #     audio,
    #     sample_rate,
    #     num_freqs=128,          # Same as mel bands for consistency
    #     f_min=20,               # Standard lower frequency bound
    #     f_max=sample_rate // 2  # Nyquist frequency
    # )

    # Time-domain features
    # features['time_domain'] = compute_time_domain_features(audio, sample_rate)

    # Frequency-domain features
    return compute_frequency_domain_features(audio, sample_rate)
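
# Minimal sanity check on a synthetic signal (the sample rate and the 440 Hz
# test tone are illustrative assumptions, not taken from the original
# pipeline).
if __name__ == "__main__":
    sr = 22050
    t = np.linspace(0, 1.0, sr, endpoint=False)
    test_audio = torch.FloatTensor(np.sin(2 * np.pi * 440.0 * t))  # 1 s, 440 Hz
    feats = compute_all_features(test_audio, sr)
    for name, value in feats.items():
        print(f"{name}: {value.item():.4f}")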